Unverified Commit cc2a77d7 authored by Andrew Sansom's avatar Andrew Sansom Committed by GitHub
Browse files

[Core] [Bugfix] Add Input Embeddings (#15428)


Signed-off-by: default avatarAndrew Sansom <andrew@protopia.ai>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatar临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: default avatarBryce1010 <bryceyx@gmail.com>
Co-authored-by: default avatarNan2018 <nan@protopia.ai>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 9e2de9b9
...@@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest ...@@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
get_sampler)
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
...@@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
additional fields. additional fields.
""" """
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
...@@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
...@@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
...@@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self): def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore self.mrope_input_positions = None # type: ignore
...@@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions. # Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None, input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None, token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None,
...@@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear() self.input_tokens[seq_id].clear()
self.inputs_embeds = inputs_embeds
if input_positions: if input_positions:
self.input_positions = input_positions self.input_positions = input_positions
else: else:
...@@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or [] self.input_positions = input_positions or []
self.token_types = token_types or [] self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None self.mrope_input_positions = mrope_input_positions or None
...@@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_index_mapping = [] self.lora_index_mapping = []
self.lora_prompt_mapping = [] self.lora_prompt_mapping = []
def __repr__(self) -> str:
return (f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")
def gen_inter_data_builder(self, num_seqs: int): def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="", request_id="",
...@@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
context_len = seq_data.get_num_computed_tokens() context_len = seq_data.get_num_computed_tokens()
# Compute tokens. # Compute tokens.
if seq_data.prompt_embeds is None:
tokens = seq_data.get_token_ids()[context_len:seq_len] tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_embeds = None
else:
tokens = [0] * (seq_len - context_len)
prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend( inter_data.token_types[seq_idx].extend(
token_types if token_types else []) token_types if token_types else [])
...@@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors. create on-device tensors.
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = [] input_tokens = list[int]()
token_types = [] inputs_embeds_lst = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens) input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types: for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types) token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
if not input_tokens: if not input_tokens and inputs_embeds is None:
# This may happen when all prefill requests hit # This may happen when all prefill requests hit
# prefix caching and there is no decode request. # prefix caching and there is no decode request.
return self.model_input_cls() return self.model_input_cls()
...@@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
token_types=token_types_tensor, token_types=token_types_tensor,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.max_batchsize_to_capture = \ self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size self.vllm_config.compilation_config.max_capture_size
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ #
self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size) {} for _ in range(self.parallel_config.pipeline_parallel_size)
] ]
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
...@@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_positions = torch.zeros(max_batch_size, input_positions = torch.zeros(max_batch_size,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
inputs_embeds = torch.zeros(
(max_batch_size, self.model_config.get_hidden_size()),
dtype=self.model_config.dtype,
device=self.device)
if self.model_config.uses_mrope: if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions, input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device) (3, 1)).cuda(device=self.device)
...@@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range( for virtual_engine in range(
self.parallel_config.pipeline_parallel_size): self.parallel_config.pipeline_parallel_size):
# Only rank 0 should print progress bar during capture # We need to not only iterate over batch sizes, but also whether
cudagraph_capture_sizes = (tqdm( # to use inputs_embeds or not, hence we use the cartesian
self.vllm_config.compilation_config. # product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes
cudagraph_inputs_embeds = (True, False)
compilation_cases = itertools.product(
cudagraph_capture_sizes, cudagraph_capture_sizes,
desc="Capturing CUDA graph shapes", cudagraph_inputs_embeds,
) if get_tensor_model_parallel_rank() == 0 else )
self.vllm_config.compilation_config. # Only rank 0 should print progress bar during capture
cudagraph_capture_sizes) if get_tensor_model_parallel_rank() == 0:
for batch_size in cudagraph_capture_sizes: compilation_cases = tqdm(
list(compilation_cases),
desc="Capturing CUDA graph shapes")
for batch_size, use_inputs_embeds in compilation_cases:
attn_metadata = ( attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch( self.attn_state.graph_capture_get_metadata_for_batch(
batch_size, batch_size,
...@@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
capture_inputs = { capture_inputs = {
"input_ids": "input_ids":
input_tokens[:batch_size], input_tokens[:batch_size],
"inputs_embeds":
inputs_embeds[:batch_size]
if use_inputs_embeds else None,
"positions": "positions":
input_positions[..., :batch_size], input_positions[..., :batch_size],
"intermediate_inputs": "intermediate_inputs":
...@@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
virtual_engine): virtual_engine):
graph_runner.capture(**capture_inputs) graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = ( self.graph_runners[virtual_engine][(
graph_runner) batch_size, use_inputs_embeds)] = graph_runner
if self.lora_config: if self.lora_config:
self._remove_dummy_loras() self._remove_dummy_loras()
...@@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][ use_inputs_embeds = model_input.inputs_embeds is not None
graph_batch_size] model_executable = self.graph_runners[virtual_engine][(
graph_batch_size, use_inputs_embeds)]
if previous_hidden_states is not None: if previous_hidden_states is not None:
previous_hidden_states = torch.cat([ previous_hidden_states = torch.cat([
previous_hidden_states, previous_hidden_states,
...@@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.vllm_config, virtual_engine): self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions, positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
...@@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler( output: SamplerOutput = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
...@@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.model_forward_time = (orig_model_forward_time + output.model_forward_time = (orig_model_forward_time +
model_forward_time) model_forward_time)
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
...@@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module): ...@@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
def capture( def capture(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors], intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
...@@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module): ...@@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module):
for _ in range(_NUM_WARMUP_ITERS): for _ in range(_NUM_WARMUP_ITERS):
self.model( self.model(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds,
positions=positions, positions=positions,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
...@@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module): ...@@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module):
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = self.model( output_hidden_or_intermediate_states = self.model(
input_ids=input_ids, input_ids=input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
positions=positions, positions=positions,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
...@@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module): ...@@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module):
self.input_buffers = { self.input_buffers = {
"input_ids": "input_ids":
input_ids, input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
"positions": "positions":
positions, positions,
"kv_caches": "kv_caches":
...@@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module): ...@@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
**kwargs, **kwargs,
...@@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module): ...@@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module):
# so the shape is not padded, we need to copy partial only # so the shape is not padded, we need to copy partial only
self.input_buffers["positions"][:positions.shape[0]].copy_( self.input_buffers["positions"][:positions.shape[0]].copy_(
positions, non_blocking=True) positions, non_blocking=True)
if inputs_embeds is not None:
self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
inputs_embeds, non_blocking=True)
if self.backend_name != "NO_ATTENTION": if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_( self.input_buffers["slot_mapping"].copy_(
......
...@@ -84,10 +84,17 @@ class PoolingModelRunner( ...@@ -84,10 +84,17 @@ class PoolingModelRunner(
# explore how to leverage it. # explore how to leverage it.
if (prefill_meta is None and decode_meta is not None if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph): and decode_meta.use_cuda_graph):
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][ model_executable = (
graph_batch_size] self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else: else:
model_executable = self.model model_executable = self.model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment