Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
......@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches)
self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
self.scheduler_config.max_num_batched_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self.model_runner.reset_dynamo_cache()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
......@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Spread multimodal embeddings over the positions of their placeholder
    tokens.

    A new tensor with one row per placeholder position is allocated with the
    same dtype and device as ``embeds`` and pre-filled with NaN; the rows
    selected by ``is_embed`` are then overwritten with ``embeds``. Rows that
    are not selected keep the NaN fill.

    NOTE(review): the mask presumably originates from
    :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed` --
    confirm against the callers.

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask marking which placeholder positions receive
            a multimodal embedding, or ``None`` when every position does.
            Its first dimension is the number of placeholder positions.
            NOTE(review): indexing a `(num_placeholders, embed_dim)` tensor
            with this mask implies a 1-D mask of shape `(num_placeholders,)`
            -- confirm against the callers.

    Returns:
        ``embeds`` itself (not a copy) when ``is_embed`` is ``None``;
        otherwise the newly allocated `(num_placeholders, embed_dim)` tensor.
    """
    if is_embed is not None:
        num_placeholders = is_embed.shape[0]
        embed_dim = embeds.shape[-1]

        # NaN marks the rows that carry no embedding.
        scattered = embeds.new_full(
            (num_placeholders, embed_dim),
            fill_value=torch.nan,
        )
        scattered[is_embed] = embeds
        return scattered

    # No mask: the embeddings already line up with the placeholders.
    return embeds
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Recover the multimodal embeddings from a placeholder tensor.

    This is the inverse operation of :func:`scatter_mm_placeholders`: the
    entries of ``placeholders`` selected by the boolean mask ``is_embed`` are
    returned. When ``is_embed`` is ``None`` the input tensor is returned
    unchanged (the same object, not a copy).
    """
    return placeholders if is_embed is None else placeholders[is_embed]
......@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
......@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
@dataclasses.dataclass(frozen=True)
......@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if (model_input.attn_metadata is not None
and model_input.attn_metadata.prefill_metadata is None
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
......@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config:
dummy_lora_requests = self._add_dummy_loras(
self.lora_config.max_loras)
assert len(dummy_lora_requests) == self.lora_config.max_loras
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
......@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables=None,
encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=decoder_dummy_data.
......
......@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
......@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
......@@ -100,7 +103,10 @@ def subtuple(obj: object,
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if type(obj) is dict:
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename,
' '.join(fields))
......@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine: int = 0
lora_ids: Optional[List[int]] = None
async_callback: Optional[Callable] = None
is_first_multi_step: bool = True
is_last_step: bool = True
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
......@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"lora_ids": self.lora_ids,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
......@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self._set_gc_threshold()
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
......@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
output=None,
) -> PrepareDecodeMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
......@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
if output is None:
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
......@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
num_fully_occupied_blocks = position // self.block_size
block_table = block_table[:num_fully_occupied_blocks + 1]
if len(block_table) == 0:
block_number = _PAD_BLOCK_ID
else:
......@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
if output is None:
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
else:
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size]
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
......@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, None, warmup_mode=True)
is_single_step = \
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_prompt or is_single_step:
self.execute_model(inputs, None, warmup_mode=True)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
is_last_step=False)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
inputs = dataclasses.replace(inputs,
is_first_multi_step=False,
is_last_step=True)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
torch.hpu.synchronize()
if profiler:
profiler.step()
......@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
warmup_mode=False,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"num_steps > 1 is not supported in HPUModelRunner")
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
return []
# last multi-step
output = self._decode_sampler_outputs(
model_input) if self.is_driver_worker else []
torch.hpu.synchronize()
if model_input.is_first_multi_step:
# first multi-step
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
real_batch_size = model_input.real_batch_size
batch_size_padded = model_input.batch_size_padded
assert input_tokens is not None
assert input_positions is not None
assert sampling_metadata is not None
assert attn_metadata is not None
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids,
attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update(
{"bypass_hpu_graphs": not use_graphs})
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
real_batch_size = model_input.real_batch_size
batch_size_padded = model_input.batch_size_padded
assert input_tokens is not None
assert input_positions is not None
assert sampling_metadata is not None
assert attn_metadata is not None
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = ("model_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
if num_steps > 1:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
self.model.model.sampler.include_gpu_probs_tensor = True
cache_orig_output_tokens_len: List[Dict] = []
def try_revert_dummy_output_tokens():
if len(cache_orig_output_tokens_len) > 0:
# Reuse the original output token ids length
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
for j, data in seq_group_metadata.seq_data.items():
orig_output_tokens_len = \
cache_orig_output_tokens_len[i][j]
data.output_token_ids = \
data.output_token_ids[:orig_output_tokens_len]
for i in range(num_steps):
if i != 0 and not self.is_driver_worker:
broadcast_data = broadcast_tensor_dict(src=0)
if 'early_exit' in broadcast_data and broadcast_data[
'early_exit']:
return [output] if num_steps == 1 else []
execute_model_kwargs.update({
"input_ids":
broadcast_data["input_ids"],
"positions":
broadcast_data["positions"],
"attn_metadata":
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.
selected_token_indices)
if self.lora_config:
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
# Compute the logits.
with self.profiler.record_event(
'internal',
('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
continue
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids, attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})
htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = ("model_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
if num_steps > 1:
output = output.sampled_token_ids
self.cached_step_outputs.append(
output.detach().clone())
htorch.core.mark_step()
if i < num_steps - 1:
if i == 0:
if model_input.async_callback is not None:
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
seq_group_metadata_list = \
ctx.seq_group_metadata_list
elif seqs is not None:
seq_group_metadata_list = seqs
else:
raise RuntimeError(
"seq_group_metadata_list is uninitialized")
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
# Skip empty steps
seq_group_metadata.state.current_step += (
num_steps - 2)
# Cache the original output token ids
cache_orig_output_tokens_len.append({})
for j, data in seq_group_metadata.seq_data.items():
cache_orig_output_tokens_len[i][j] = \
len(data.output_token_ids)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
max_output_len = sampling_metadata.seq_groups[
0].sampling_params.max_tokens
if len(data.output_token_ids) < max_output_len - 1:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token = (540, )
data.output_token_ids += (dummy_token)
else:
broadcast_tensor_dict({'early_exit': True},
src=0)
if num_steps == 1:
return [output]
else:
try_revert_dummy_output_tokens()
return []
result = self._prepare_decode(seq_group_metadata_list,
output=output)
execute_model_kwargs.update({
"input_ids":
result.input_tokens,
"positions":
result.input_positions,
"attn_metadata":
self.trim_attn_metadata(result.attn_metadata)
})
model_kwargs_broadcast_data = {
"input_ids": result.input_tokens,
"positions": result.input_positions,
"attn_metadata": vars(result.attn_metadata)
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
else:
try_revert_dummy_output_tokens()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
return [output] if self.is_driver_worker else []
else:
return []
return output if type(output) is list else [output]
def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=False)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
)
return sampler_outputs
if self.lora_config:
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
# Compute the logits.
with self.profiler.record_event(
'internal', ('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
output.outputs = output.outputs[:real_batch_size]
htorch.core.mark_step()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
return [output]
def _make_decode_output(
    self,
    next_token_ids: List[List[int]],
    seq_groups: List[SequenceGroupToSample],
) -> SamplerOutput:
    """Wrap raw sampled token ids into a ``SamplerOutput``.

    Rows of ``next_token_ids`` are consumed in order: one row per sequence
    id, walking ``seq_groups`` front to back and each group's ``seq_ids``
    front to back. Only the first element of each row is used as the token.

    Every sequence output carries a single-entry logprob mapping for its
    token, and all of them share one ``Logprob(0.0)`` instance -- no real
    log-probabilities are computed here.

    Args:
        next_token_ids: One row per sequence in the batch; element 0 of the
            row is the sampled token id.
        seq_groups: The sequence groups whose ``seq_ids`` decide how the
            rows are grouped.

    Returns:
        A ``SamplerOutput`` holding one ``CompletionSequenceGroupOutput``
        per entry of ``seq_groups``.
    """
    placeholder_logprob = Logprob(0.0)

    def sample_for(seq_id, row):
        # Build the output for one sequence from its row of token ids.
        token = next_token_ids[row][0]
        return SequenceOutput(seq_id, token, {token: placeholder_logprob})

    group_outputs = []
    cursor = 0  # index of the next unconsumed row in next_token_ids
    for group in seq_groups:
        samples = [
            sample_for(seq_id, row)
            for row, seq_id in enumerate(group.seq_ids, start=cursor)
        ]
        cursor += len(samples)
        group_outputs.append(CompletionSequenceGroupOutput(samples, None))
    return SamplerOutput(group_outputs)
def shutdown_inc(self):
can_finalize_inc = False
......
# SPDX-License-Identifier: Apache-2.0
import sys
import dataclasses
import gc
import inspect
......@@ -15,7 +16,7 @@ import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from tqdm import tqdm
from tqdm.auto import tqdm
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
......@@ -1108,6 +1109,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
self.enforce_eager_bs_threshould = sys.maxsize
if envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD is not None and envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD > 0:
self.enforce_eager_bs_threshould = envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
......@@ -1717,7 +1722,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and decode_meta.use_cuda_graph:
if prefill_meta is None and decode_meta.use_cuda_graph and \
model_input.input_tokens.shape[0] <= self.enforce_eager_bs_threshould:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
......
# SPDX-License-Identifier: Apache-2.0
###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.hpu_model_runner import ModelInputForHPU
from vllm.worker.hpu_worker import HPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepHPUWorker(HPUWorker):
    """HPU worker that keeps one prepared model input for a multi-step run.

    The model input built on the first step is cached. Every later step
    reuses the cached object and only swaps the ``is_first_multi_step`` /
    ``is_last_step`` flags, so non-first steps broadcast just those two flags
    instead of the full input.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Model input remembered from the most recent first step (driver) or
        # the most recent full broadcast (non-driver workers).
        self.cached_model_input: Optional[ModelInputForHPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Build the driver's inputs for this step and broadcast them to the
        other workers.

        On the first step the worker input and model input are prepared from
        ``execute_model_req``; on later steps the cached model input is
        reused together with a default ``WorkerInput``. In both cases the
        step flags of the request are written into the returned model input.

        Returns:
            ``(model_input, worker_input, {})``.
        """
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        first_step = execute_model_req.is_first_multi_step
        last_step = execute_model_req.is_last_step

        model_input: ModelInputForHPU
        worker_input: WorkerInput
        if not first_step:
            # Subsequent steps: reuse what the first step prepared.
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        else:
            # First step: prepare worker input and model input normally.
            worker_input = dataclasses.replace(
                self.prepare_worker_input(
                    execute_model_req=execute_model_req),
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input = self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids)
            callback = execute_model_req.async_callback
            if callback:
                model_input = dataclasses.replace(model_input,
                                                  async_callback=callback)

        model_input = dataclasses.replace(model_input,
                                          is_first_multi_step=first_step,
                                          is_last_step=last_step)

        if self.do_metadata_broadcast:
            if first_step:
                # Full payload: worker input merged with model input.
                payload = worker_input.as_broadcastable_tensor_dict()
                payload.update(model_input.as_broadcastable_tensor_dict())
            else:
                # Flags-only payload; receivers patch their cached input.
                payload = {
                    "is_first_multi_step": first_step,
                    "is_last_step": last_step,
                }
            broadcast_tensor_dict(payload, src=0)

        # The empty dict keeps the return shape compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`.
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        """
        Produce the inputs for one execution step.

        The driver builds (and optionally broadcasts) the inputs from
        ``execute_model_req``; every other worker receives them through
        ``broadcast_tensor_dict``.

        Returns:
            ``(model_input, worker_input, {})``, or ``None`` when the driver
            has no request or a non-driver worker receives an empty
            broadcast.
        """
        if not self.is_driver_worker:
            received = broadcast_tensor_dict(src=0)
            if not received:
                # Empty broadcast: the driver has nothing more to run.
                return None

            if len(received) == 2:
                # Two entries means the flags-only message of a non-first
                # step: update the flags on the cached input and reuse it.
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=received["is_first_multi_step"],
                    is_last_step=received["is_last_step"])
                return self.cached_model_input, WorkerInput(), {}

            # Full payload: rebuild both inputs and remember the model input
            # for the following steps.
            worker_input = WorkerInput.from_broadcasted_tensor_dict(received)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(received))
            self.cached_model_input = model_input
            return model_input, worker_input, {}

        if execute_model_req is None:
            if self.do_metadata_broadcast:
                # No more requests to process for now. The other workers
                # loop on broadcast_tensor_dict and leave that loop when the
                # driver broadcasts an empty input, so send one to stop
                # their execution loop.
                broadcast_tensor_dict({}, src=0)
            return None

        model_input, worker_input, _ = self._get_driver_input_and_broadcast(
            execute_model_req)
        if model_input.is_first_multi_step:
            self.cached_model_input = model_input
        return model_input, worker_input, {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment