Unverified commit 770ec602, authored by Chen Zhang, committed by GitHub
Browse files

[Model] Add support for the multi-modal Llama 3.2 model (#8811)


Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 4f1ba084
from transformers.models.mllama import configuration_mllama as mllama_hf_config
class MllamaTextConfig(mllama_hf_config.MllamaTextConfig):
    """Mllama text config that always reports ``is_encoder_decoder=True``.

    Exists solely to override ``is_encoder_decoder``:

    - transformers regards mllama as ``is_encoder_decoder=False``
    - vllm needs ``is_encoder_decoder=True`` to enable cross-attention
    """

    def __init__(self, **kwargs):
        """Build the upstream text config, then force the encoder-decoder flag.

        All keyword arguments are forwarded unchanged to the parent
        constructor.
        """
        super().__init__(**kwargs)
        # Assigned after the parent constructor runs, so this value replaces
        # whatever the parent stored on the attribute.
        self.is_encoder_decoder = True
class MllamaConfig(mllama_hf_config.MllamaConfig):
    """Mllama config that builds dict text configs as ``MllamaTextConfig``."""

    def __init__(self, text_config=None, **kwargs):
        """Initialise the config, coercing a dict ``text_config`` first.

        Args:
            text_config: When this is a ``dict``, it is expanded into the
                ``MllamaTextConfig`` defined in this module. Any other value
                (including ``None``) is handed to the parent untouched.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        resolved_text_config = (
            MllamaTextConfig(**text_config)
            if isinstance(text_config, dict)
            else text_config
        )
        super().__init__(text_config=resolved_text_config, **kwargs)
......@@ -111,7 +111,6 @@ def get_tokenizer(
'encoding and decoding.',
FutureWarning,
stacklevel=2)
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)
......
......@@ -18,7 +18,8 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)
......@@ -52,6 +53,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
......@@ -194,6 +196,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......@@ -202,6 +206,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
logits = self.model.compute_logits(hidden_or_intermediate_states,
......@@ -288,8 +294,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
raise NotImplementedError(
"Multi-modal encoder-decoder models are not supported yet")
logger.info("Starting profile run for multi-modal models.")
batch_size = 0
for group_id in range(max_num_seqs):
......@@ -297,24 +302,39 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, _ = self.input_registry \
.dummy_data_for_profiling(self.model_config,
decoder_seq_data, decoder_dummy_multi_modal_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
seq_len,
self.mm_registry)
self.mm_registry,
is_encoder_data=False)
encoder_seq_data, encoder_dummy_multi_modal_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
assert len(decoder_seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
f"but got: {len(decoder_seq_data.prompt_token_ids)}")
assert decoder_dummy_multi_modal_data is None or \
encoder_dummy_multi_modal_data is None, (
"Multi-modal data can't be provided in both encoder and decoder"
)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
seq_data={group_id: decoder_seq_data},
sampling_params=sampling_params,
block_tables=None,
encoder_seq_data=seq_data,
encoder_seq_data=encoder_seq_data,
cross_block_table=None,
multi_modal_data=decoder_dummy_multi_modal_data
or encoder_dummy_multi_modal_data,
)
seqs.append(seq)
......
......@@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario(
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.model_config.is_multimodal_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment