Unverified commit f1c66454, authored by Tugsbayasgalan Manlaibaatar, committed by GitHub
Browse files

Make voxtral compile friendly (#33959)


Signed-off-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
parent c870eb9e
...@@ -41,6 +41,7 @@ from vllm.multimodal.processing.processor import ( ...@@ -41,6 +41,7 @@ from vllm.multimodal.processing.processor import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .utils import ( from .utils import (
_flatten_embeddings, _flatten_embeddings,
...@@ -337,9 +338,21 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ...@@ -337,9 +338,21 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
assert input_ids is not None assert input_ids is not None
pool_size = self.config.audio_config.block_pool_size pool_size = self.config.audio_config.block_pool_size
inputs_embeds = inputs_embeds.view( if is_torch_equal_or_newer("2.11"):
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size inputs_embeds = inputs_embeds.view(
) inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
)
else:
# TODO: Drop this workaround once PyTorch versions older than 2.11 no
# longer need to be supported.
# Workaround: use reshape + clone to break the view chain and avoid an
# output-aliases-input bug in torch.compile's AOT autograd cache.
# Without clone(), if any downstream operation returns a view that is
# connected to this view of inputs_embeds, the AOT autograd cache
# fails to pickle the ViewMetaSequence containing SymInt shapes.
# This is expected to be fixed in PyTorch 2.11 and later.
# Issue: https://github.com/pytorch/pytorch/issues/174299
inputs_embeds = inputs_embeds.reshape(
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
).clone()
whisper_positions = _expand_tensor(positions, pool_size) whisper_positions = _expand_tensor(positions, pool_size)
audio_hidden_states = self.whisper_encoder.whisper_encoder( audio_hidden_states = self.whisper_encoder.whisper_encoder(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment