"docs/vscode:/vscode.git/clone" did not exist on "371f7e4ca2a44fbd4a63cd641efb279274a717f4"
Unverified Commit 37e8182b authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[v1] Add Whisper model support (encoder-decoder) (#21088)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
parent 4db44264
...@@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec) super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.kv_cache_spec = kv_cache_spec
self.device = device
self.vllm_config = vllm_config
self.layer_names = layer_names
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min( self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs, self.vllm_config.scheduler_config.max_num_seqs,
...@@ -52,4 +49,4 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -52,4 +49,4 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
m.max_query_len = 1 # decode-only m.max_query_len = 1 # decode-only
return self.build(0, m) return self.build(0, m)
\ No newline at end of file
...@@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)
...@@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder(
self.parallel_config) self.parallel_config)
self.headdim = self.model_config.get_head_size() self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
# Sliding window size to be used with the AOT scheduler will be # Sliding window size to be used with the AOT scheduler will be
# populated on first build() call. # populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None self.aot_sliding_window: Optional[tuple[int, int]] = None
......
...@@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder( ...@@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
......
...@@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder( ...@@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder(
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
): ):
self.kv_cache_spec = kv_cache_spec super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config spec_config = vllm_config.speculative_config
......
...@@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder( ...@@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.device = device super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
model_config = vllm_config.model_config model_config = vllm_config.model_config
self.num_heads_q = model_config.get_num_attention_heads( self.num_heads_q = model_config.get_num_attention_heads(
......
...@@ -72,6 +72,9 @@ class CommonAttentionMetadata: ...@@ -72,6 +72,9 @@ class CommonAttentionMetadata:
logits_indices_padded: Optional[torch.Tensor] = None logits_indices_padded: Optional[torch.Tensor] = None
num_logits_indices: Optional[int] = None num_logits_indices: Optional[int] = None
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None
@dataclass @dataclass
class UbatchSlice: class UbatchSlice:
...@@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.layer_names = layer_names
self.vllm_config = vllm_config
self.device = device
@abstractmethod @abstractmethod
def build(self, def build(self,
......
...@@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder( ...@@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder(
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
): ):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert XFORMERS_AVAILABLE assert XFORMERS_AVAILABLE
self.kv_cache_spec = kv_cache_spec
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self._num_decodes = 0 self._num_decodes = 0
self._num_decode_tokens = 0 self._num_decode_tokens = 0
......
...@@ -144,8 +144,8 @@ class Scheduler(SchedulerInterface): ...@@ -144,8 +144,8 @@ class Scheduler(SchedulerInterface):
) )
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also # projector if needed) for MM models as well as encoder-decoder
# has the Transformer architecture (e.g., ViT). # transformers.
self.max_num_encoder_input_tokens = encoder_compute_budget self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models), # NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0 # the encoder cache will not be initialized because cache size is 0
...@@ -775,15 +775,19 @@ class Scheduler(SchedulerInterface): ...@@ -775,15 +775,19 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache. # in the decoder's KV cache.
continue continue
# The same encoder input has already been scheduled in the current if not self.is_encoder_decoder:
# step. # We are not using the encoder cache for encoder-decoder models,
if request.mm_hashes[i] in mm_hashes_to_schedule: # yet.
continue if request.mm_hashes[i] in mm_hashes_to_schedule:
# The same encoder input has already been scheduled in the
# current step.
continue
if self.encoder_cache_manager.check_and_update_cache(request, i): if self.encoder_cache_manager.check_and_update_cache(
# The encoder input is already computed and cached from a request, i):
# previous step. # The encoder input is already computed and cached from a
continue # previous step.
continue
# If no encoder input chunking is allowed, we do not want to # If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would # partially schedule a multimodal item. If the scheduled range would
...@@ -1047,7 +1051,13 @@ class Scheduler(SchedulerInterface): ...@@ -1047,7 +1051,13 @@ class Scheduler(SchedulerInterface):
mm_positions = request.mm_positions[input_id] mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset start_pos = mm_positions.offset
num_tokens = mm_positions.length num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens: if self.is_encoder_decoder and request.num_computed_tokens > 0:
# With Whisper, as soon as we've generated a single token,
# we know we're done with the encoder input. Cross Attention
# KVs have been calculated and cached already.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
elif start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored # The encoder output is already processed and stored
# in the decoder's KV cache. # in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input( self.encoder_cache_manager.free_encoder_input(
......
...@@ -325,7 +325,6 @@ class Processor: ...@@ -325,7 +325,6 @@ class Processor:
) -> tuple[Optional[str], EngineCoreRequest]: ) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models. # TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params, lora_request) self._validate_params(params, lora_request)
if trace_headers is not None: if trace_headers is not None:
...@@ -384,10 +383,6 @@ class Processor: ...@@ -384,10 +383,6 @@ class Processor:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
sampling_params = None sampling_params = None
pooling_params = None pooling_params = None
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
......
...@@ -61,12 +61,16 @@ from vllm.v1.attention.backends.utils import ( ...@@ -61,12 +61,16 @@ from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
reorder_batch_to_split_decodes_and_prefills) reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
CrossAttentionSpec,
EncoderOnlyAttentionSpec, EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec, KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec) MambaSpec, SlidingWindowSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsLists, LogprobsTensors, DraftTokenIds, LogprobsLists, LogprobsTensors,
ModelRunnerOutput, SamplerOutput) ModelRunnerOutput, SamplerOutput)
...@@ -208,6 +212,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -208,6 +212,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config) model_config)
if self.model_config.is_encoder_decoder:
# Maximum length of the encoder input, only for encoder-decoder
# models.
self.max_encoder_len = self.mm_registry.\
get_encdec_max_encoder_len(model_config)
else:
self.max_encoder_len = 0
# Sampler # Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
...@@ -265,7 +277,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -265,7 +277,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the block_sizes in the kv cache config. # the block_sizes in the kv cache config.
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len, # We need to use the encoder length for encoder-decoer
# because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
...@@ -798,6 +812,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -798,6 +812,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
src=self.input_batch.prev_sampled_token_ids[ src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0]) prev_common_req_indices_tensor, 0])
def _get_encoder_seq_lens(
self,
scheduler_output: "SchedulerOutput",
kv_cache_spec: KVCacheSpec,
num_reqs: int,
) -> Optional[np.ndarray]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
return None
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduler_output.scheduled_encoder_inputs:
req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len
return encoder_seq_lens
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -937,6 +969,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -937,6 +969,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same group share the same metadata. # in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups): self.kv_cache_config.kv_cache_groups):
encoder_seq_lens = self._get_encoder_seq_lens(
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs)
if isinstance(kv_cache_group_spec.kv_cache_spec, if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec): EncoderOnlyAttentionSpec):
...@@ -981,6 +1015,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -981,6 +1015,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices_padded=logits_indices_padded, logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0), num_logits_indices=logits_indices.size(0),
causal=True, causal=True,
encoder_seq_lens=encoder_seq_lens,
) )
if self.speculative_config and \ if self.speculative_config and \
...@@ -1253,10 +1288,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1253,10 +1288,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
return logits_indices_padded return logits_indices_padded
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): def _batch_mm_kwargs_from_scheduler(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
"""Batch multimodal kwargs from scheduled encoder inputs.
Args:
scheduler_output: The scheduler output containing scheduled encoder
inputs.
Returns:
A tuple of (mm_kwargs, req_ids_pos) where:
- mm_kwargs: List of multimodal kwargs items to be batched
- mm_hashes_pos: List of (mm_hash, position_info) tuples
"""
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
return return [], []
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[MultiModalKwargsItem]()
# list of tuple (mm_hash, position_info) # list of tuple (mm_hash, position_info)
...@@ -1270,6 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1270,6 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hashes_pos.append( mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id])) (mm_hash, req_state.mm_positions[mm_input_id]))
return mm_kwargs, mm_hashes_pos
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
scheduler_output)
if not mm_kwargs:
return
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order. # we process it separately to preserve item order.
...@@ -1360,6 +1419,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1360,6 +1419,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds.append(mm_embeds_item) mm_embeds.append(mm_embeds_item)
return mm_embeds return mm_embeds
def _extract_encoder_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, torch.Tensor]:
"""Extract encoder inputs for encoder-decoder models.
This method extracts multimodal input features from scheduled encoder
inputs and formats them for the encoder-decoder model forward pass.
"""
# Batch the multi-modal inputs using the helper method.
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
if not mm_kwargs:
return {}
# Group MM kwargs by modality and extract features
encoder_features = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
# Add the grouped features to encoder_features dict
# This allows the model to receive them as kwargs (e.g.,
# input_features=...)
encoder_features.update(mm_kwargs_group)
return encoder_features
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper. # get raw model out of the cudagraph wrapper.
if isinstance(self.model, CUDAGraphWrapper): if isinstance(self.model, CUDAGraphWrapper):
...@@ -1631,7 +1719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1631,7 +1719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# _prepare_inputs may reorder the batch, so we must gather multi # _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order # modal outputs after that to ensure the correct order
if self.supports_mm_inputs and get_pp_group().is_first_rank: if (self.supports_mm_inputs and get_pp_group().is_first_rank
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output)
...@@ -1673,6 +1762,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1673,6 +1762,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
if (self.model_config.is_encoder_decoder
and scheduler_output.scheduled_encoder_inputs):
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
model_kwargs.update(encoder_inputs)
return ( return (
num_scheduled_tokens, num_scheduled_tokens,
num_input_tokens, num_input_tokens,
...@@ -2591,17 +2685,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2591,17 +2685,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens, remove_lora): num_scheduled_tokens, remove_lora):
if self.supports_mm_inputs: model_kwargs = self._init_model_kwargs(num_tokens)
if (self.supports_mm_inputs
and not self.model_config.is_encoder_decoder):
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens] inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
model_kwargs = { model_kwargs = {
**self._init_model_kwargs(num_tokens), **model_kwargs,
**self._dummy_mm_kwargs(num_reqs), **self._dummy_mm_kwargs(num_reqs),
} }
else: else:
input_ids = self.input_ids.gpu[:num_tokens] input_ids = self.input_ids.gpu[:num_tokens]
inputs_embeds = None inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens] positions = self.mrope_positions.gpu[:, :num_tokens]
...@@ -2823,7 +2918,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2823,7 +2918,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_budget = self.mm_budget mm_budget = self.mm_budget
assert mm_budget is not None assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0: if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when # modality with the max possible input tokens even when
...@@ -3170,7 +3264,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3170,7 +3264,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"for more details.") "for more details.")
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len, max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
...@@ -3443,7 +3537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3443,7 +3537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items(): for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY: if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec = EncoderOnlyAttentionSpec( attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
...@@ -3485,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3485,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue continue
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like # TODO(lucas): move the attention specs into the model layers like
# the attention backends # the attention backends
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
...@@ -3513,12 +3606,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3513,12 +3606,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
use_mla=use_mla) use_mla=use_mla)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache. # encoder-only attention does not need KV cache.
continue continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else: else:
raise ValueError( raise ValueError(
f"Unknown attention type: {attn_module.attn_type}") f"Unknown attention type: {attn_module.attn_type}")
......
...@@ -12,6 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings ...@@ -12,6 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec
...@@ -269,7 +270,17 @@ def bind_kv_cache( ...@@ -269,7 +270,17 @@ def bind_kv_cache(
# One typical case is encoder-decoder model, e.g., bart. # One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer # The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index. # has different layer_name but the same layer_index.
raise NotImplementedError
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if current_platform.is_cuda():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
layer_name = layer_names[0] layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name]) runner_kv_caches.append(kv_caches[layer_name])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment