Unverified Commit 0b8bb86b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[1/N] Initial prototype for multi-modal processor (#10044)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bb7991aa
import enum import enum
from typing import TYPE_CHECKING, List, Optional, Union from typing import List, Optional, Union
from vllm.inputs.data import DecoderOnlyInputs from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -9,23 +9,20 @@ from vllm.sequence import RequestMetrics ...@@ -9,23 +9,20 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
class Request: class Request:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
inputs: "DecoderOnlyInputs", inputs: DecoderOnlyInputs,
sampling_params: SamplingParams, sampling_params: SamplingParams,
eos_token_id: Optional[int], eos_token_id: Optional[int],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.inputs = inputs self.inputs = SingletonInputsAdapter(inputs)
self.sampling_params = sampling_params self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
...@@ -41,17 +38,17 @@ class Request: ...@@ -41,17 +38,17 @@ class Request:
assert sampling_params.max_tokens is not None assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens self.max_tokens = sampling_params.max_tokens
self.prompt = inputs.get("prompt") self.prompt = self.inputs.prompt
self.prompt_token_ids = inputs["prompt_token_ids"] self.prompt_token_ids = self.inputs.prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = [] self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy() self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Raw multimodal data before the mm input mapper (e.g., PIL images). # Raw multimodal data before the mm input mapper (e.g., PIL images).
self.mm_data = inputs.get("multi_modal_data") self.mm_data = self.inputs.multi_modal_data
self.mm_processor_kwargs = inputs.get("mm_processor_kwargs") self.mm_processor_kwargs = self.inputs.mm_processor_kwargs
mm_positions = inputs.get("multi_modal_placeholders") mm_positions = self.inputs.multi_modal_placeholders
if mm_positions: if mm_positions:
# FIXME(woosuk): Support other modalities. # FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", []) self.mm_positions = mm_positions.get("image", [])
...@@ -64,8 +61,7 @@ class Request: ...@@ -64,8 +61,7 @@ class Request:
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
inputs=DecoderOnlyInputs( inputs=token_inputs(
type="token",
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt, prompt=request.prompt,
multi_modal_data=request.mm_data, multi_modal_data=request.mm_data,
...@@ -114,7 +110,7 @@ class Request: ...@@ -114,7 +110,7 @@ class Request:
return RequestStatus.get_finished_reason(self.status) return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool: def has_encoder_inputs(self) -> bool:
return self.mm_data is not None return len(self.mm_data) > 0
@property @property
def num_encoder_inputs(self) -> int: def num_encoder_inputs(self) -> int:
......
...@@ -28,7 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -28,7 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.base import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -148,19 +148,29 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -148,19 +148,29 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens=seq_lens, query_lens=seq_lens,
) )
def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, def _compute_multi_modal_input(
seq_data: SequenceData, computed_len: int, self,
mm_processor_kwargs: Dict[str, Any]): seq_data: SequenceData,
computed_len: int,
seq_group_metadata: SequenceGroupMetadata,
):
# NOTE: mm_data only includes the subset of multi-modal items that # NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions. # intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group, range(computed_len, len(seq_data.get_token_ids()))) seq_group_metadata,
range(computed_len, len(seq_data.get_token_ids())),
)
if not mm_data: if not mm_data:
return return None, None, None
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
# special processing for mrope position deltas. # special processing for mrope position deltas.
mrope_positions = None mrope_positions = None
...@@ -202,7 +212,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -202,7 +212,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = [] multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[ multi_modal_placeholder_maps: Dict[
str, str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
...@@ -223,11 +233,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -223,11 +233,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
mrope_positions = None mrope_positions = None
if seq_group_metadata.multi_modal_data: if seq_group_metadata.multi_modal_data:
mm_kwargs, placeholder_maps, mrope_positions = self \ (
._compute_multi_modal_input( mm_kwargs,
seq_group_metadata, seq_data, computed_len, placeholder_maps,
seq_group_metadata.mm_processor_kwargs) mrope_positions,
multi_model_kwargs_list.append(mm_kwargs) ) = self._compute_multi_modal_input(seq_data, computed_len,
seq_group_metadata)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items(): for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend( multi_modal_placeholder_maps[modality].extend(
placeholder_map) placeholder_map)
...@@ -302,7 +315,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -302,7 +315,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
......
...@@ -716,7 +716,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -716,7 +716,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
context_lens: List[int] = [] context_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_model_kwargs_list: List[MultiModalKwargs] = [] multi_modal_kwargs_list: List[MultiModalKwargs] = []
if len(seq_group_metadata_list) == 0: if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty() return PreparePromptMetadata.empty()
...@@ -777,7 +777,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -777,7 +777,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mm_data = seq_group_metadata.multi_modal_data mm_data = seq_group_metadata.multi_modal_data
if mm_data: if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data) mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_model_kwargs_list.append(mm_kwargs) multi_modal_kwargs_list.append(mm_kwargs)
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
...@@ -876,7 +876,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -876,7 +876,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
multi_modal_placeholder_index_maps= multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here None # FIXME(kzawora): mutli-modality will not work here
) )
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return PreparePromptMetadata(input_tokens=input_tokens, return PreparePromptMetadata(input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
......
...@@ -252,7 +252,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -252,7 +252,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs. # Multi-modal inputs.
multi_model_kwargs: Optional[MultiModalKwargs] = None, multi_modal_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[ multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None, str, MultiModalPlaceholderMap]] = None,
...@@ -373,7 +373,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -373,7 +373,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping or []) prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.multi_model_kwargs = multi_model_kwargs self.multi_modal_kwargs = multi_modal_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit self.prefix_cache_hit = prefix_cache_hit
...@@ -661,10 +661,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -661,10 +661,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not mm_data: if not mm_data:
return return
mm_kwargs = self.multi_modal_input_mapper( if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_data, mm_kwargs = mm_data
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) else:
inter_data.multi_model_kwargs = mm_kwargs mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
inter_data.multi_modal_kwargs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas. # special processing for mrope position deltas.
...@@ -938,11 +943,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -938,11 +943,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
) )
# Multi-modal data. # Multi-modal data.
multi_model_kwargs_list = [ multi_modal_kwargs_list = [
data.multi_model_kwargs for data in self.inter_data_list data.multi_modal_kwargs for data in self.inter_data_list
if data.multi_model_kwargs is not None if data.multi_modal_kwargs is not None
] ]
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
......
...@@ -67,7 +67,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -67,7 +67,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
# Multi-modal data support # Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config) .create_input_mapper(self.model_config)
# Lazy initialization. # Lazy initialization.
...@@ -122,7 +123,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -122,7 +123,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids: List[int] = [] input_block_ids: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = [] multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -144,12 +145,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -144,12 +145,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
mm_data = seq_group_metadata.multi_modal_data mm_data = seq_group_metadata.multi_modal_data
if mm_data: if mm_data:
# Process multi-modal data if self.mm_registry.has_processor(self.model_config):
mm_kwargs = self.multi_modal_input_mapper( mm_kwargs = mm_data
mm_data, else:
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, mm_kwargs = self.multi_modal_input_mapper(
) mm_data,
multi_model_kwargs_list.append(mm_kwargs) seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
max_seq_len = max(seq_lens) max_seq_len = max(seq_lens)
assert max_seq_len > 0 assert max_seq_len > 0
...@@ -167,7 +171,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -167,7 +171,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens, return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
......
...@@ -70,7 +70,8 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -70,7 +70,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
) )
# Multi-modal data support # Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config) .create_input_mapper(self.model_config)
# Lazy initialization. # Lazy initialization.
...@@ -102,7 +103,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -102,7 +103,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
seq_lens: List[int] = [] seq_lens: List[int] = []
past_lens: List[int] = [] past_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = [] multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[ multi_modal_placeholder_maps: Dict[
str, str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
...@@ -222,11 +223,15 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -222,11 +223,15 @@ class OpenVINOModelRunner(ModelRunnerBase):
mm_data, placeholder_maps = MultiModalPlaceholderMap \ mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range) .from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.multi_modal_input_mapper( if self.mm_registry.has_processor(self.model_config):
mm_data, mm_kwargs = mm_data
mm_processor_kwargs=seq_group_metadata. else:
mm_processor_kwargs) mm_kwargs = self.multi_modal_input_mapper(
multi_model_kwargs_list.append(mm_kwargs) mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items(): for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend( multi_modal_placeholder_maps[modality].extend(
...@@ -275,7 +280,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -275,7 +280,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return ModelInput( return ModelInput(
input_tokens, input_tokens,
......
...@@ -160,7 +160,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -160,7 +160,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = [] multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[ multi_modal_placeholder_maps: Dict[
str, str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
...@@ -191,8 +191,16 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -191,8 +191,16 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
mm_data, placeholder_maps = MultiModalPlaceholderMap \ mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range) .from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) if self.runner.mm_registry.has_processor(
multi_model_kwargs_list.append(mm_kwargs) self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.runner.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items(): for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend( multi_modal_placeholder_maps[modality].extend(
...@@ -264,7 +272,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -264,7 +272,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
block_tables=torch.tensor([], device=self.device, dtype=torch.int), block_tables=torch.tensor([], device=self.device, dtype=torch.int),
) )
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment