Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
...@@ -157,13 +157,19 @@ class TPUWorker: ...@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches) runner_kv_caches)
self.model_runner._dummy_run( self.model_runner._dummy_run(
runner_kv_caches, self.scheduler_config.max_num_batched_tokens)
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
# Synchronize before measuring the memory usage. # Synchronize before measuring the memory usage.
xm.wait_device_ops() xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self.model_runner.reset_dynamo_cache()
# Get the maximum amount of memory used by the model weights and # Get the maximum amount of memory used by the model weights and
# intermediate activations. # intermediate activations.
m = xm.get_memory_info(self.device) m = xm.get_memory_info(self.device)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch import torch
...@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs( ...@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation " "instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.") "of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
:class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of :func:`scatter_mm_placeholders`.
"""
if is_embed is None:
return placeholders
return placeholders[is_embed]
...@@ -16,6 +16,7 @@ from vllm.config import VllmConfig ...@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
...@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import ( ...@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__) logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
...@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if num_steps > 1: if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in " raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner") "EncoderDecoderModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if (model_input.attn_metadata is not None if (model_input.attn_metadata is not None
and model_input.attn_metadata.prefill_metadata is None and model_input.attn_metadata.prefill_metadata is None
and model_input.attn_metadata.decode_metadata.use_cuda_graph): and model_input.attn_metadata.decode_metadata.use_cuda_graph):
...@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config:
dummy_lora_requests = self._add_dummy_loras(
self.lora_config.max_loras)
assert len(dummy_lora_requests) == self.lora_config.max_loras
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
...@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables=None, block_tables=None,
encoder_seq_data=encoder_dummy_data.seq_data, encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None, cross_block_table=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=decoder_dummy_data.multi_modal_data multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_data.multi_modal_data, or encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=decoder_dummy_data. multi_modal_placeholders=decoder_dummy_data.
......
...@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, ...@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs) MultiModalKwargs)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
SequenceGroupMetadata) Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import (bind_kv_cache, is_pin_memory_available, from vllm.utils import (bind_kv_cache, is_pin_memory_available,
make_tensor_with_pad) make_tensor_with_pad)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -100,7 +103,10 @@ def subtuple(obj: object, ...@@ -100,7 +103,10 @@ def subtuple(obj: object,
if to_override is None: if to_override is None:
to_override = {} to_override = {}
fields = set(to_copy) | set(to_override.keys()) fields = set(to_copy) | set(to_override.keys())
values = {f: to_override.get(f, getattr(obj, f)) for f in fields} if type(obj) is dict:
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE: if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename, _TYPE_CACHE[typename] = collections.namedtuple(typename,
' '.join(fields)) ' '.join(fields))
...@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase): ...@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine: int = 0 virtual_engine: int = 0
lora_ids: Optional[List[int]] = None lora_ids: Optional[List[int]] = None
async_callback: Optional[Callable] = None async_callback: Optional[Callable] = None
is_first_multi_step: bool = True
is_last_step: bool = True
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
...@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase): ...@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded": self.batch_size_padded, "batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"lora_ids": self.lora_ids, "lora_ids": self.lora_ids,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict return tensor_dict
...@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self._set_gc_threshold() self._set_gc_threshold()
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
def _set_gc_threshold(self) -> None: def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations. # for comprehensive description of gc generations.
...@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
output=None,
) -> PrepareDecodeMetadata: ) -> PrepareDecodeMetadata:
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
...@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id() if output is None:
input_tokens.append([generation_token]) generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len() seq_len = seq_data.get_len()
position = seq_len - 1 position = seq_len - 1
...@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens.append(seq_len) seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
num_fully_occupied_blocks = position // self.block_size
block_table = block_table[:num_fully_occupied_blocks + 1]
if len(block_table) == 0: if len(block_table) == 0:
block_number = _PAD_BLOCK_ID block_number = _PAD_BLOCK_ID
else: else:
...@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table) block_tables.append(block_table)
input_tokens = torch.tensor(input_tokens, if output is None:
dtype=torch.long, input_tokens = torch.tensor(input_tokens,
device=self.device) dtype=torch.long,
device=self.device)
else:
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size]
input_positions = torch.tensor(input_positions, input_positions = torch.tensor(input_positions,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
...@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler.start() profiler.start()
for _ in range(times): for _ in range(times):
inputs = self.prepare_model_input(seqs) inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, None, warmup_mode=True) is_single_step = \
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_prompt or is_single_step:
self.execute_model(inputs, None, warmup_mode=True)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
is_last_step=False)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
inputs = dataclasses.replace(inputs,
is_first_multi_step=False,
is_last_step=True)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
torch.hpu.synchronize() torch.hpu.synchronize()
if profiler: if profiler:
profiler.step() profiler.step()
...@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): ...@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
warmup_mode=False, warmup_mode=False,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1: if not model_input.is_first_multi_step:
raise ValueError( if not model_input.is_last_step:
"num_steps > 1 is not supported in HPUModelRunner") # not first or last multi-step
return []
# last multi-step
output = self._decode_sampler_outputs(
model_input) if self.is_driver_worker else []
torch.hpu.synchronize()
if model_input.is_first_multi_step:
# first multi-step
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
real_batch_size = model_input.real_batch_size
batch_size_padded = model_input.batch_size_padded
assert input_tokens is not None
assert input_positions is not None
assert sampling_metadata is not None
assert attn_metadata is not None
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids,
attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update(
{"bypass_hpu_graphs": not use_graphs})
if self.lora_config: htorch.core.mark_step()
assert model_input.lora_requests is not None if self.is_driver_worker:
assert model_input.lora_mapping is not None model_event_name = ("model_"
self.set_active_loras(model_input.lora_requests, f"{'prompt' if is_prompt else 'decode'}_"
model_input.lora_mapping) f"bs{batch_size}_"
input_tokens = model_input.input_tokens f"seq{seq_len}_"
input_positions = model_input.input_positions f"graphs{'T' if use_graphs else 'F'}")
attn_metadata = model_input.attn_metadata else:
sampling_metadata = model_input.sampling_metadata model_event_name = 'model_executable'
real_batch_size = model_input.real_batch_size if num_steps > 1:
batch_size_padded = model_input.batch_size_padded # in case of multi-step scheduling
assert input_tokens is not None # we only want to pythonize in the last step
assert input_positions is not None sampling_metadata.skip_sampler_cpu_output = True
assert sampling_metadata is not None self.model.model.sampler.include_gpu_probs_tensor = True
assert attn_metadata is not None cache_orig_output_tokens_len: List[Dict] = []
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None def try_revert_dummy_output_tokens():
batch_size = input_tokens.size(0) if len(cache_orig_output_tokens_len) > 0:
seq_len = self._seq_len(attn_metadata) # Reuse the original output token ids length
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) for i, seq_group_metadata in enumerate(
self._check_config(batch_size, seq_len, is_prompt, warmup_mode) seq_group_metadata_list):
for j, data in seq_group_metadata.seq_data.items():
orig_output_tokens_len = \
cache_orig_output_tokens_len[i][j]
data.output_token_ids = \
data.output_token_ids[:orig_output_tokens_len]
for i in range(num_steps):
if i != 0 and not self.is_driver_worker:
broadcast_data = broadcast_tensor_dict(src=0)
if 'early_exit' in broadcast_data and broadcast_data[
'early_exit']:
return [output] if num_steps == 1 else []
execute_model_kwargs.update({
"input_ids":
broadcast_data["input_ids"],
"positions":
broadcast_data["positions"],
"attn_metadata":
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.
selected_token_indices)
if self.lora_config:
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
# Compute the logits.
with self.profiler.record_event(
'internal',
('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
continue
lora_mask: torch.Tensor = None if model_input.async_callback is not None:
lora_logits_mask: torch.Tensor = None model_input.async_callback()
if self.lora_config: # Sample the next token.
assert model_input.lora_ids is not None with self.profiler.record_event(
lora_mask, lora_logits_mask = self.create_lora_mask( 'internal', ('sample_'
input_tokens, model_input.lora_ids, attn_metadata.is_prompt) f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
execute_model_kwargs = { f'seq{seq_len}')):
"input_ids": input_tokens, output = self.model.sample(
"positions": input_positions, logits=logits,
"attn_metadata": self.trim_attn_metadata(attn_metadata), sampling_metadata=sampling_metadata,
"intermediate_tensors": intermediate_tensors, )
"lora_mask": lora_mask, if num_steps > 1:
"virtual_engine": model_input.virtual_engine, output = output.sampled_token_ids
**(model_input.multi_modal_kwargs or {}), self.cached_step_outputs.append(
} output.detach().clone())
if htorch.utils.internal.is_lazy(): htorch.core.mark_step()
execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs}) if i < num_steps - 1:
if i == 0:
htorch.core.mark_step() if model_input.async_callback is not None:
if self.is_driver_worker: ctx = model_input.async_callback.keywords[ # type: ignore
model_event_name = ("model_" "ctx"]
f"{'prompt' if is_prompt else 'decode'}_" seq_group_metadata_list = \
f"bs{batch_size}_" ctx.seq_group_metadata_list
f"seq{seq_len}_" elif seqs is not None:
f"graphs{'T' if use_graphs else 'F'}") seq_group_metadata_list = seqs
else:
raise RuntimeError(
"seq_group_metadata_list is uninitialized")
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
# Skip empty steps
seq_group_metadata.state.current_step += (
num_steps - 2)
# Cache the original output token ids
cache_orig_output_tokens_len.append({})
for j, data in seq_group_metadata.seq_data.items():
cache_orig_output_tokens_len[i][j] = \
len(data.output_token_ids)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
max_output_len = sampling_metadata.seq_groups[
0].sampling_params.max_tokens
if len(data.output_token_ids) < max_output_len - 1:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token = (540, )
data.output_token_ids += (dummy_token)
else:
broadcast_tensor_dict({'early_exit': True},
src=0)
if num_steps == 1:
return [output]
else:
try_revert_dummy_output_tokens()
return []
result = self._prepare_decode(seq_group_metadata_list,
output=output)
execute_model_kwargs.update({
"input_ids":
result.input_tokens,
"positions":
result.input_positions,
"attn_metadata":
self.trim_attn_metadata(result.attn_metadata)
})
model_kwargs_broadcast_data = {
"input_ids": result.input_tokens,
"positions": result.input_positions,
"attn_metadata": vars(result.attn_metadata)
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
else:
try_revert_dummy_output_tokens()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
return [output] if self.is_driver_worker else []
else:
return []
return output if type(output) is list else [output]
def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=False)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else: else:
model_event_name = 'model_executable' return sampler_outputs
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
)
if self.lora_config: def _make_decode_output(
LoraMask.setLoraMask( self,
lora_logits_mask.index_select( next_token_ids: List[List[int]],
0, sampling_metadata.selected_token_indices)) seq_groups: List[SequenceGroupToSample],
) -> SamplerOutput:
# Compute the logits. zero_logprob = Logprob(0.0)
with self.profiler.record_event( sampler_outputs = []
'internal', ('compute_logits_' batch_idx = 0
f'{"prompt" if is_prompt else "decode"}_bs' for seq_group in seq_groups:
f'{batch_size}_' seq_ids = seq_group.seq_ids
f'seq{seq_len}')): seq_outputs = []
sampling_metadata.selected_token_indices = None for seq_id in seq_ids:
logits = self.model.compute_logits(hidden_states, next_token_id = next_token_ids[batch_idx][0]
sampling_metadata) seq_outputs.append(
htorch.core.mark_step() SequenceOutput(seq_id, next_token_id,
# Only perform sampling in the driver worker. {next_token_id: zero_logprob}))
if not self.is_driver_worker: batch_idx += 1
return [] sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
if model_input.async_callback is not None: return SamplerOutput(sampler_outputs)
model_input.async_callback()
# Sample the next token.
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
output.outputs = output.outputs[:real_batch_size]
htorch.core.mark_step()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
return [output]
def shutdown_inc(self): def shutdown_inc(self):
can_finalize_inc = False can_finalize_inc = False
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import sys
import dataclasses import dataclasses
import gc import gc
import inspect import inspect
...@@ -15,7 +16,7 @@ import numpy as np ...@@ -15,7 +16,7 @@ import numpy as np
import torch import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm.auto import tqdm
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
...@@ -1108,6 +1109,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1108,6 +1109,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if hasattr(self, "_builder_cls"): if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls` # multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self)) self.builder = self._builder_cls(weakref.proxy(self))
self.enforce_eager_bs_threshould = sys.maxsize
if envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD is not None and envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD > 0:
self.enforce_eager_bs_threshould = envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
...@@ -1717,7 +1722,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1717,7 +1722,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# virtual engines share the same kv cache. # virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine virtual_engine = model_input.virtual_engine
previous_hidden_states = kwargs.get("previous_hidden_states") previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph and \
model_input.input_tokens.shape[0] <= self.enforce_eager_bs_threshould:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][ model_executable = self.graph_runners[virtual_engine][
......
# SPDX-License-Identifier: Apache-2.0
###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.hpu_model_runner import ModelInputForHPU
from vllm.worker.hpu_worker import HPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepHPUWorker(HPUWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cached_model_input: Optional[ModelInputForHPU] = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
"""
Get the driver input and broadcast it to other workers.
"""
assert self.is_driver_worker
assert execute_model_req.virtual_engine == 0
is_first_multi_step = execute_model_req.is_first_multi_step
is_last_step = execute_model_req.is_last_step
if is_first_multi_step:
# on first step we prepare the worker input and model input normally
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
worker_input = dataclasses.replace(
worker_input,
num_steps=execute_model_req.num_lookahead_slots + 1)
model_input: ModelInputForHPU = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input = dataclasses.replace(
model_input,
async_callback=execute_model_req.async_callback)
else:
# on subsequent steps we reuse the worker input and model input
assert self.cached_model_input is not None
model_input = self.cached_model_input
worker_input = WorkerInput()
model_input = dataclasses.replace(
model_input,
is_first_multi_step=is_first_multi_step,
is_last_step=is_last_step)
if self.do_metadata_broadcast:
if is_first_multi_step:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
broadcast_data = {
"is_first_multi_step": is_first_multi_step,
"is_last_step": is_last_step,
}
broadcast_tensor_dict(broadcast_data, src=0)
# Returning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
torch.Tensor]]]:
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
execute_model_req)
if model_input.is_first_multi_step:
self.cached_model_input = model_input
return model_input, worker_input, {}
else:
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
if len(broadcast_data) == 2:
assert self.cached_model_input is not None
self.cached_model_input = dataclasses.replace(
self.cached_model_input,
is_first_multi_step=broadcast_data["is_first_multi_step"],
is_last_step=broadcast_data["is_last_step"])
empty_worker_input = WorkerInput()
return self.cached_model_input, empty_worker_input, {}
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.cached_model_input = model_input
return model_input, worker_input, {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment