Unverified Commit cc2a77d7 authored by Andrew Sansom's avatar Andrew Sansom Committed by GitHub
Browse files

[Core] [Bugfix] Add Input Embeddings (#15428)


Signed-off-by: default avatarAndrew Sansom <andrew@protopia.ai>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatar临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: default avatarBryce1010 <bryceyx@gmail.com>
Co-authored-by: default avatarNan2018 <nan@protopia.ai>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 9e2de9b9
......@@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
get_sampler)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
......@@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
......@@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
......@@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
......@@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
......@@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
......@@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()
self.inputs_embeds = inputs_embeds
if input_positions:
self.input_positions = input_positions
else:
......@@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
......@@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_index_mapping = []
self.lora_prompt_mapping = []
def __repr__(self) -> str:
return (f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")
def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="",
......@@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
if seq_data.prompt_embeds is None:
tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_embeds = None
else:
tokens = [0] * (seq_len - context_len)
prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
......@@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = []
token_types = []
input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
if not input_tokens:
if not input_tokens and inputs_embeds is None:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
......@@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
return self.model_input_cls(
input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata,
......@@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
#
self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[
......@@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_positions = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
inputs_embeds = torch.zeros(
(max_batch_size, self.model_config.get_hidden_size()),
dtype=self.model_config.dtype,
device=self.device)
if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device)
......@@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
# Only rank 0 should print progress bar during capture
cudagraph_capture_sizes = (tqdm(
self.vllm_config.compilation_config.
# We need to not only iterate over batch sizes, but also whether
# to use inputs_embeds or not, hence we use the cartesian
# product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes
cudagraph_inputs_embeds = (True, False)
compilation_cases = itertools.product(
cudagraph_capture_sizes,
desc="Capturing CUDA graph shapes",
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.
cudagraph_capture_sizes)
for batch_size in cudagraph_capture_sizes:
cudagraph_inputs_embeds,
)
# Only rank 0 should print progress bar during capture
if get_tensor_model_parallel_rank() == 0:
compilation_cases = tqdm(
list(compilation_cases),
desc="Capturing CUDA graph shapes")
for batch_size, use_inputs_embeds in compilation_cases:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
......@@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
capture_inputs = {
"input_ids":
input_tokens[:batch_size],
"inputs_embeds":
inputs_embeds[:batch_size]
if use_inputs_embeds else None,
"positions":
input_positions[..., :batch_size],
"intermediate_inputs":
......@@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
virtual_engine):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
self.graph_runners[virtual_engine][(
batch_size, use_inputs_embeds)] = graph_runner
if self.lora_config:
self._remove_dummy_loras()
......@@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
use_inputs_embeds = model_input.inputs_embeds is not None
model_executable = self.graph_runners[virtual_engine][(
graph_batch_size, use_inputs_embeds)]
if previous_hidden_states is not None:
previous_hidden_states = torch.cat([
previous_hidden_states,
......@@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
......@@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
......@@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
......@@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
def capture(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor],
......@@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module):
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
positions=positions,
intermediate_tensors=intermediate_inputs,
**kwargs,
......@@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module):
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = self.model(
input_ids=input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
positions=positions,
intermediate_tensors=intermediate_inputs,
**kwargs,
......@@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module):
self.input_buffers = {
"input_ids":
input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
"positions":
positions,
"kv_caches":
......@@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs,
......@@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module):
# so the shape is not padded, we need to copy partial only
self.input_buffers["positions"][:positions.shape[0]].copy_(
positions, non_blocking=True)
if inputs_embeds is not None:
self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
inputs_embeds, non_blocking=True)
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(
......
......@@ -84,10 +84,17 @@ class PoolingModelRunner(
# explore how to leverage it.
if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph):
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else:
model_executable = self.model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment