Unverified Commit 4de71463 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 deprecation] Remove V0 HPU backend (#21131)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent ac9fb732
......@@ -116,23 +116,6 @@ def rocm_platform_plugin() -> Optional[str]:
return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
def hpu_platform_plugin() -> Optional[str]:
is_hpu = False
logger.debug("Checking if HPU platform is available.")
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
if is_hpu:
logger.debug("Confirmed HPU platform is available.")
else:
logger.debug("HPU platform is not available because "
"habana_frameworks is not found.")
except Exception as e:
logger.debug("HPU platform is not available because: %s", str(e))
return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
def xpu_platform_plugin() -> Optional[str]:
is_xpu = False
logger.debug("Checking if XPU platform is available.")
......@@ -208,7 +191,6 @@ builtin_platform_plugins = {
'tpu': tpu_platform_plugin,
'cuda': cuda_platform_plugin,
'rocm': rocm_platform_plugin,
'hpu': hpu_platform_plugin,
'xpu': xpu_platform_plugin,
'cpu': cpu_platform_plugin,
'neuron': neuron_platform_plugin,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TYPE_CHECKING, Optional
import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_name: str = "hpu"
device_type: str = "hpu"
dispatch_key: str = "HPU"
ray_device_key: str = "HPU"
dist_backend: str = "hccl"
device_control_env_var: str = "HABANA_VISIBLE_MODULES"
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
logger.info("Using HPUAttention backend.")
return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@classmethod
def inference_mode(cls):
return torch.no_grad()
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.hpu.set_device(device)
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config = vllm_config.scheduler_config
parallel_config = vllm_config.parallel_config
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
if vllm_config.speculative_config is not None:
raise NotImplementedError(
"Speculative decoding is not implemented for HPU")
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 128
if (parallel_config.distributed_executor_backend == 'mp'
and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",
None) is not None:
logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork "
"might cause application hangs on exit. Using "
"VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
"as it was explicitly requested.")
else:
logger.warning(
"On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork "
"might cause application hangs on exit. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"To override that behavior, please set "
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.")
return False
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa
......@@ -54,7 +54,6 @@ class _Backend(enum.Enum):
FLASHMLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
CUTLASS_MLA_VLLM_V1 = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
......@@ -69,7 +68,6 @@ class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
TPU = enum.auto()
HPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
......@@ -154,9 +152,6 @@ class Platform:
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU
def is_hpu(self) -> bool:
return self._enum == PlatformEnum.HPU
def is_xpu(self) -> bool:
return self._enum == PlatformEnum.XPU
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
from typing import Any, Callable
import torch
......@@ -75,18 +74,6 @@ def load_general_plugins():
if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch._dynamo.config.disable = True
elif current_platform.is_hpu():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
if is_lazy:
torch._dynamo.config.disable = True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
plugins = load_plugins_by_group(group=DEFAULT_PLUGINS_GROUP)
# general plugins, we only need to execute the loaded functions
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import collections
import contextlib
import dataclasses
import functools
import gc
import itertools
import math
import os
import time
from array import array
from enum import Enum, IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union)
import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
_TYPE_CACHE = {}
# These values are assumed to be zero in several places.
# Use caution when updating them!
_PAD_SLOT_ID = 0
_PAD_BLOCK_ID = 0
LORA_WARMUP_RANK = 8
DUMMY_TOKEN_ID = -1
class PhaseType(Enum):
PREFILL = 'prefill'
PREFIX_PREFILL = 'prefix_prefill'
DECODE = 'decode'
def subtuple(obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if type(obj) is dict:
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename,
' '.join(fields))
return _TYPE_CACHE[typename](**values)
def round_up(value: int, k: int):
return (value + k - 1) // k * k
def align_workers(value, op):
group = get_world_group().cpu_group
world_size = torch.distributed.get_world_size()
if world_size <= 1:
return value
value_t = torch.tensor(value, device='cpu')
torch.distributed.all_reduce(value_t, op=op, group=group)
return value_t.item()
def setup_profiler():
schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
DEVICE = 'hpu'
activities = [torch.profiler.ProfilerActivity.CPU]
activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE ==
'hpu' else [])
#from habana_frameworks.torch.activity_profiler import DebugActivity
#debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS]
profiler = torch.profiler.profile(
schedule=schedule,
activities=activities,
#debug_activities=debug_activities,
on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
use_gzip=True),
record_shapes=False,
with_stack=True)
return profiler
def pad_list(input, k, v):
input_len = len(input)
target_len = round_up(input_len, k)
padding = target_len - input_len
return input + [v] * padding
def gather_list(input, indices, v):
return [input[i] if i is not None else v for i in indices]
def flatten(in_list):
return list(itertools.chain(*in_list))
def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
slot_mapping = slot_mapping.flatten()
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
if is_prompt:
indices = indices.unflatten(0, (-1, block_size))[:, 0]
offsets = None
else:
offsets = torch.fmod(slot_mapping, block_size)
return indices, offsets
def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
if module.__class__.__name__.endswith(suffix):
def forward_hook(module, args, output):
htorch.core.mark_step()
return output
module.register_forward_hook(forward_hook)
for child_name, child_module in module.named_children():
modify_decoder_layer(child_module)
class HpuModelAdapter:
def __init__(self, model, vllm_config):
self.model = model
self.sampler = get_sampler()
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype
enforce_eager = vllm_config.model_config.enforce_eager
if not htorch.utils.internal.is_lazy() and not enforce_eager:
if os.getenv('VLLM_REGIONAL_COMPILATION',
'true').lower() == 'true':
self.regional_compilation_layers_list = [
RMSNorm, VocabParallelEmbedding
]
self._regional_compilation(self.model)
else:
self.model = torch.compile(self.model,
backend='hpu_backend',
dynamic=False)
def _regional_compilation(self,
module,
parent_module=None,
module_name=None):
if isinstance(module, torch.nn.ModuleList):
for children_name, children_module in module.named_children():
self._compile_region(module, children_name, children_module)
elif any(
isinstance(module, layer)
for layer in self.regional_compilation_layers_list):
self._compile_region(parent_module, module_name, module)
else:
for children_name, children_module in module.named_children():
self._regional_compilation(children_module, module,
children_name)
def _compile_region(self, model, name, module):
module = torch.compile(module, backend='hpu_backend', dynamic=False)
setattr(model, name, module)
def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
if (attn_metadata is None
or (self.prefill_use_fusedsdpa \
and attn_metadata.block_list is None)
or not attn_metadata.is_prompt):
return attn_metadata
prefill_metadata = attn_metadata
seq_lens_t = prefill_metadata.seq_lens_tensor
context_lens_t = prefill_metadata.context_lens_tensor
query_lens_t = seq_lens_t - context_lens_t
block_list = attn_metadata.block_list
max_context_len = (block_list.size(-1) //
batch_size if block_list is not None else 0)
max_context_len = max_context_len * self.block_size
past_mask = torch.arange(0,
max_context_len,
dtype=torch.int32,
device=device)
past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
len_mask = (torch.arange(0, seq_len, device=device,
dtype=torch.int32).view(1, seq_len).ge(
query_lens_t.unsqueeze(-1)).view(
batch_size, 1, 1, seq_len))
causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
device=device,
dtype=torch.bool),
diagonal=1)
mask = causal_mask.logical_or(len_mask)
mask = torch.concat((past_mask, mask), dim=-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
return attn_metadata
def _set_block_mapping(self, metadata, batch_size, device, dtype):
mask = torch.arange(0,
self.block_size,
device=device,
dtype=torch.int32).unsqueeze(0)
mask = mask >= metadata.block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
if os.environ.get('VLLM_USE_FAKE_HPU',
'0') == '0' and htorch.utils.internal.is_lazy():
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
else:
# Unfortunately one_hot on CPU/torch.compile mode/eager mode
# doesn't handle out of bounds classes so we need to convert
# all negative values to 0 (block_mapping) or bs (block_groups)
block_groups = metadata.block_groups.to(torch.long)
block_mapping = torch.nn.functional.relu(block_groups)
block_mapping = torch.nn.functional.one_hot(block_mapping,
num_classes=batch_size)
oob_values = block_groups.lt(0)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
block_groups.masked_fill_(oob_values, batch_size)
metadata = metadata._replace(block_groups=block_groups)
block_mapping = block_mapping.to(dtype)
metadata = metadata._replace(block_mapping=block_mapping,
attn_bias=attn_bias)
return metadata
def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
dtype):
if attn_metadata.is_prompt:
meta = attn_metadata
attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
device, dtype)
else:
meta = attn_metadata
attn_metadata = self._set_block_mapping(meta, batch_size, device,
dtype)
return attn_metadata
def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
selected_token_indices = kwargs.pop('selected_token_indices')
if 'warmup_mode' in kwargs:
kwargs.pop('warmup_mode')
virtual_engine = 0
if 'virtual_engine' in kwargs:
virtual_engine = kwargs.pop('virtual_engine')
input_ids = kwargs['input_ids']
attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'),
input_ids.size(0),
input_ids.size(1),
input_ids.device, self.dtype)
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0,
selected_token_indices)
return hidden_states
def compute_logits(self, *args, **kwargs):
return self.model.compute_logits(*args, **kwargs)
def sample(self, *args, **kwargs):
return self.sampler(*args, **kwargs)
class PreparePromptMetadata(NamedTuple):
input_tokens: torch.Tensor
input_positions: List[List[int]]
attn_metadata: Optional[AttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
lora_index_mapping: List[List[int]]
lora_prompt_mapping: List[List[int]]
lora_requests: Set[LoRARequest]
multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]]
slot_mapping: List[List[int]]
lora_ids: List[int]
@classmethod
def empty(cls):
return PreparePromptMetadata(input_tokens=[],
input_positions=[],
attn_metadata=None,
seq_lens=[],
query_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
multi_modal_kwargs=None,
slot_mapping=[],
lora_ids=[])
class PrepareDecodeMetadata(NamedTuple):
input_tokens: torch.Tensor
input_positions: List[List[int]]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[List[int]]
lora_prompt_mapping: List[List[int]]
lora_requests: Set[LoRARequest]
slot_mapping: List[List[int]]
lora_ids: List[int]
@classmethod
def empty(cls):
return PrepareDecodeMetadata(input_tokens=[],
input_positions=[],
attn_metadata=None,
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
lora_ids=[])
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU")
@dataclasses.dataclass(frozen=True)
class ModelInputForHPU(ModelRunnerInputBase):
"""
This base class contains metadata needed for the base model forward pass
but not metadata for possible additional steps, e.g., sampling. Model
runners that run additional steps should subclass this method to add
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
real_batch_size: Optional[int] = None
batch_size_padded: Optional[int] = None
virtual_engine: int = 0
lora_ids: Optional[List[int]] = None
async_callback: Optional[Callable] = None
is_first_multi_step: bool = True
is_last_step: bool = True
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"real_batch_size": self.real_batch_size,
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"lora_ids": self.lora_ids,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForHPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForHPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclasses.dataclass(frozen=True)
class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
# Used for speculative decoding. We do not broadcast it because it is only
# used by the driver worker.
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"lora_ids": self.lora_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForHPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
# FIXME(kzawora): this fails for whatever reason - why?
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
"""
Helper class for shared methods between GPU model runners.
"""
_model_input_cls: Type[TModelInputForHPU]
def __init__(
self,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
return_hidden_states: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
environment.set_model_config(self.model_config)
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.sliding_window = (self.model_config.get_sliding_window()
if self.model_config is not None else None)
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.device = self.device_config.device
self.enforce_eager = self.model_config.enforce_eager
self.max_num_seqs = self.scheduler_config.max_num_seqs
# NOTE(kzawora): Change that to scheduler_config.max_num_prefill_seqs
# once padding-aware scheduling gets merged
self.max_num_prefill_seqs = 64
self.max_model_len = self.scheduler_config.max_model_len
self.max_num_batched_tokens = \
self.scheduler_config.max_num_batched_tokens
self.block_size = self.cache_config.block_size
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = self.cache_config.cache_dtype
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Lazy initialization
self.lora_manager: LRUCacheWorkerLoRAManager = None
self.model: torch.nn.Module = None
self.inc_initialized_successfully = False
# Profiler stats
self.profiler = HabanaHighLevelProfiler()
self.profiler_counter_helper = HabanaProfilerCounterHelper()
self.seen_configs: set = set()
self._mem_margin: Optional[int] = None
HPUBucketingContext = get_bucketing_context()
self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
self.max_num_prefill_seqs,
self.block_size,
self.max_num_batched_tokens,
False, self.max_model_len)
self.graphed_buckets: Set[Any] = set()
self._set_gc_threshold()
if self.vllm_config.cache_config.enable_prefix_caching:
os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
assert os.environ.get(
"VLLM_CONTIGUOUS_PA",
"").lower() != "true", "Contiguous PA doesn't support APC"
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
# For delayed sampling
self.cached_step_inputs: List[
ModelInputForHPUWithSamplingMetadata] = []
def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
# We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority)
# to set particular generation threshold or use simpler
# VLLM_GC_THR_MULTIPLIER to multiply default values.
default_gc_thrs = list(gc.get_threshold())
requested_gc_thrs = [0] * len(default_gc_thrs)
for i in range(len(default_gc_thrs)):
requested_gc_thrs[i] = int(
os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i]))
if requested_gc_thrs == default_gc_thrs:
gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER',
2))
requested_gc_thrs = [
t * gc_thr_multiplier for t in default_gc_thrs
]
gc.set_threshold(*requested_gc_thrs)
self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
'false').lower() == 'true'
def load_model(self) -> None:
import habana_frameworks.torch.core as htcore
if self.model_config.quantization == 'inc' or \
self.model_config.quantization == 'fp8':
htcore.hpu_set_env()
with HabanaMemoryProfiler() as m:
with HabanaMemoryProfiler() as m_getmodel:
self.model = get_model(vllm_config=self.vllm_config)
msg = ("Pre-loading model weights on "
f"{next(self.model.parameters()).device} "
f"took {m_getmodel.get_summary_string()}")
logger.info(msg)
if self.lora_config:
assert hasattr(self.model, "embedding_modules"
), "Model does not have embedding_modules"
assert hasattr(
self.model, "embedding_padding_modules"
), "Model does not have embedding_padding_modules"
assert not self.lora_config.bias_enabled, \
"Bias support in LoRA is not enabled in HPU yet."
assert not self.lora_config.fully_sharded_loras, \
"Fully sharded LoRAs is not enabled in HPU yet."
# Use get_text_config() in case of multimodal models
text_config = self.model_config.hf_config.get_text_config()
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=text_config.
max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.model_config.quantization == 'inc':
logger.info("Preparing model with INC..")
with HabanaMemoryProfiler() as m_inc:
from neural_compressor.torch.quantization import (
FP8Config, convert, prepare)
config = FP8Config.from_json_file(
os.getenv("QUANT_CONFIG", ""))
if config.measure:
self.model = prepare(self.model, config)
elif config.quantize:
self.model = convert(self.model, config)
htcore.hpu_initialize(self.model,
mark_only_scales_as_const=True)
self.inc_initialized_successfully = True
logger.info("Preparing model with INC took %s",
m_inc.get_summary_string())
else:
self.model = self.model.to("hpu")
htcore.mark_step()
modify_decoder_layer(self.model)
torch.hpu.synchronize()
with HabanaMemoryProfiler() as m_wrap:
self.model = _maybe_wrap_in_hpu_graph(
self.model, vllm_config=self.vllm_config)
msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
logger.info(msg)
self.model_memory_usage = m.consumed_device_memory
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)
def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
real_batch_size = len(seq_group_metadata_list)
batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
if batch_size_padding > 0:
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
return seq_group_metadata_list, real_batch_size, batch_size_padded
def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
*args, **kwargs)
def get_model(self) -> nn.Module:
return self.model
def _use_graphs(self, batch_size, seq_len, is_prompt):
if self.enforce_eager:
return False
if self.skip_warmup:
return True
return (batch_size, seq_len, is_prompt) in self.graphed_buckets
def _is_valid_bucket(self, bucket):
return bucket[0] * bucket[1] <= self.max_num_batched_tokens
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
lora_index_mapping: List[List[int]] = []
lora_prompt_mapping: List[List[int]] = []
lora_requests: Set[LoRARequest] = set()
seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size
if context_len == seq_len \
and self.vllm_config.cache_config.enable_prefix_caching:
# Fully cached prompt - compute only last token
context_len = context_len - 1
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert context_len == 0
# actual prompt lens
context_lens.append(context_len)
query_lens.append(seq_len - context_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(context_len, seq_len)))
mm_kwargs = seq_group_metadata.multi_modal_data
if mm_kwargs:
multi_modal_kwargs_list.append(mm_kwargs)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, seq_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_query_len = max(query_lens)
sum_query_len = sum(query_lens)
real_num_seqs = len(query_lens)
assert max_query_len > 0
max_prompt_len = max(
self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len),
self.block_size)
lora_ids: List[int] = []
for seq_group_metadata, context_len in zip(seq_group_metadata_list,
context_lens):
lora_id = seq_group_metadata.lora_int_id
lora_ids.append(lora_id)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * max_prompt_len
lora_prompt_mapping.extend(
[lora_id] *
(max_prompt_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if any(context_lens):
assert not self.scheduler_config.chunked_prefill_enabled
# prefix caching
max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
itertools.chain.from_iterable(
bt if len(bt) == max_num_block else bt +
([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
for bt in prefix_block_tables))
pad_len = len(prefix_block_list)
prefix_block_list = pad_list(prefix_block_list, pad_len,
_PAD_BLOCK_ID)
prefix_block_list_tensor = torch.tensor(prefix_block_list,
dtype=torch.long,
device=self.device)
else:
prefix_block_list_tensor = None
input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_len=max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
slot_mapping = make_tensor_with_pad(slot_mapping,
max_len=max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long,
device=self.device)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.long,
device=self.device)
block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, True)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
block_list=prefix_block_list_tensor,
block_mapping=None,
block_usage=None,
block_indices=block_indices,
block_offsets=block_offsets,
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
context_lens_tensor=context_lens_tensor,
num_prefills=real_num_seqs,
num_prefill_tokens=sum_query_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=
None, # FIXME(kzawora): multi-modality will not work here
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return PreparePromptMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping,
lora_ids=lora_ids)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
output=None,
) -> PrepareDecodeMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[List[int]] = []
lora_prompt_mapping: List[List[int]] = []
lora_requests: Set[LoRARequest] = set()
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
lora_ids: List[int] = []
dummy_slots = itertools.cycle(
range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
lora_ids.append(lora_id)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
if output is None:
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
num_fully_occupied_blocks = position // self.block_size
block_table = block_table[:num_fully_occupied_blocks + 1]
if len(block_table) == 0:
block_number = _PAD_BLOCK_ID
else:
block_number = block_table[position // self.block_size]
if block_number == _PAD_BLOCK_ID:
slot = next(dummy_slots)
else:
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id)
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
if output is None:
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
else:
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size]
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
num_decode_tokens = sum(seq_lens)
last_block_usage = [
slot[0] % self.block_size + 1 for slot in slot_mapping
]
block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
for bt, lbu in zip(block_tables, last_block_usage)
if bt]
block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
padding_fn = None
if self.use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
padding_fn = lambda tensor, pad_value: gather_list(
tensor, indices, pad_value)
else:
block_bucket_size = \
self.bucketing_ctx.get_padded_decode_num_blocks(
len(block_list))
padding_fn = lambda tensor, pad_value: pad_list(
tensor, block_bucket_size, pad_value)
block_list = padding_fn(block_list, _PAD_BLOCK_ID)
block_groups = padding_fn(block_groups, -1)
block_usage = padding_fn(block_usage, 1)
block_list = torch.tensor(block_list,
dtype=torch.int,
device=self.device)
block_groups = torch.tensor(block_groups,
dtype=torch.int,
device=self.device)
block_usage = torch.tensor(block_usage,
dtype=self.model_config.dtype,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, False)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
block_list=block_list,
block_mapping=None,
block_usage=block_usage,
block_indices=block_indices,
block_offsets=block_offsets,
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
context_lens_tensor=None,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
lora_ids=lora_ids)
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[TModelInputForHPU, SamplingMetadata]:
if len(seq_group_metadata_list) == 0:
return self._model_input_cls(), None
input_tokens = None
input_positions = None
lora_mapping = None
lora_requests = None
multi_modal_kwargs = None
batch_type = None
seq_lens = None
query_lens = None
real_batch_size = None
batch_size_padded = None
self.event_start = self.profiler.get_timestamp_us()
is_prompt = seq_group_metadata_list[0].is_prompt
base_event_name = 'prompt' if is_prompt else 'decode'
self.profiler.start('internal', base_event_name)
seq_group_metadata_list, real_batch_size, batch_size_padded = (
self._add_dummy_seq(seq_group_metadata_list, is_prompt))
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
multi_modal_kwargs,
slot_mapping,
lora_ids,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
decode_lora_ids,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
seq_lens, query_lens,
self.device,
self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# NOTE(kzawora): Here we diverge from GPU code - we don't
# support mixed batches, so we either use decode or prefill
# inputs, without coalescing.
assert (num_prefills == 0 and num_decode_tokens > 0) or (
num_prefills > 0
and num_decode_tokens == 0), "HPU does not support mixed batches!"
if num_decode_tokens > 0:
input_tokens = decode_input_tokens
input_positions = decode_input_positions
slot_mapping = decode_slot_mapping
lora_index_mapping = decode_lora_index_mapping
lora_prompt_mapping = decode_lora_prompt_mapping
lora_requests = decode_lora_requests
lora_ids = decode_lora_ids
# FIXME: We need to adjust selected_token_indices to accommodate
# for padding
max_len = input_tokens.size(1)
paddings = [max_len - q for q in query_lens]
paddings = [0] + paddings[:-1]
paddings = list(itertools.accumulate(paddings))
paddings_prompt_logprobs = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.sampling_params.prompt_logprobs is not None \
and seq_group_metadata.is_prompt:
paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
paddings = torch.tensor(
paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
dtype=sampling_metadata.selected_token_indices.dtype,
device=sampling_metadata.selected_token_indices.device)
sampling_metadata.selected_token_indices.add_(paddings)
if self.lora_config:
lora_mapping = LoRAMapping(
**dict(index_mapping=lora_index_mapping,
prompt_mapping=lora_prompt_mapping,
is_prefill=(num_prefills > 0)))
else:
lora_mapping = None
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
raise NotImplementedError("Mixed batch is not supported on HPU")
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"selected_token_indices": sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
"seq_lens": seq_lens,
"query_lens": query_lens
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
attn_metadata = prefill_attn_metadata if \
prefill_attn_metadata is not None else decode_attn_metadata
return self._model_input_cls(input_tokens=input_tokens,
seq_lens=seq_lens,
query_lens=query_lens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_requests=lora_requests,
lora_mapping=lora_mapping,
multi_modal_kwargs=multi_modal_kwargs,
real_batch_size=real_batch_size,
batch_size_padded=batch_size_padded,
lora_ids=lora_ids), \
sampling_metadata
def _seq_len(self, attn_metadata):
if attn_metadata.num_prefills != 0:
return attn_metadata.slot_mapping.size(1)
else:
return attn_metadata.block_list.numel()
def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.
# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias',
'seq_lens_tensor',
'context_lens_tensor',
'block_list',
'block_mapping',
'block_usage',
'slot_mapping',
'is_prompt',
'block_indices',
'block_offsets',
'block_groups',
])
return attention_metadata
def create_dummy_seq_group_metadata(self,
group_id,
seq_len,
is_prompt,
lora_request=None):
sampling_params = SamplingParams(temperature=0)
num_blocks = math.ceil(seq_len / self.block_size)
seq_len = max(seq_len, 1)
if is_prompt:
input_len = seq_len
output_len = 0
block_tables = None
else:
input_len = seq_len - 1
output_len = 1
block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
prompt_token_ids = [0] * input_len
output_token_ids = [1] * output_len
prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821
seq_data = SequenceData(prompt_token_ids_array)
seq_data.output_token_ids = output_token_ids
return SequenceGroupMetadata(request_id=str(group_id),
is_prompt=(output_len == 0),
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=block_tables,
lora_request=lora_request)
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[kv_caches])
_, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
max_batch_size = min(self.max_num_seqs,
self.max_num_batched_tokens // max_seq_len)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
return
def warmup_scenario(self,
batch_size,
seq_len,
is_prompt,
kv_caches,
is_pt_profiler_run=False,
is_lora_profile_run=False) -> None:
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
scenario_name = ("warmup_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# passed in, which contains a lora from the lora warmup path.
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config and is_lora_profile_run:
assert self.lora_manager is not None
with self.lora_manager.dummy_lora_cache():
for idx in range(self.lora_config.max_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(batch_size)
]
self.profiler.start('internal', scenario_name)
times = 3 if use_graphs or is_pt_profiler_run else 1
if is_prompt:
seqs = [
self.create_dummy_seq_group_metadata(
i,
seq_len,
is_prompt,
lora_request=dummy_lora_requests_per_seq[i]
if dummy_lora_requests_per_seq else None)
for i in range(batch_size)
]
else:
# FIXME: seq_len is actually number of blocks
blocks = [seq_len // batch_size for _ in range(batch_size)]
blocks[0] += seq_len % batch_size
seqs = [
self.create_dummy_seq_group_metadata(
i,
b * self.block_size - 1,
is_prompt,
lora_request=dummy_lora_requests_per_seq[i]
if dummy_lora_requests_per_seq else None)
for i, b in enumerate(blocks)
]
torch.hpu.synchronize()
profiler = None
if is_pt_profiler_run and self.is_driver_worker:
profiler = setup_profiler()
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
is_single_step = \
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_prompt or is_single_step:
self.execute_model(inputs, None, warmup_mode=True)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
is_last_step=False)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
inputs = dataclasses.replace(inputs,
is_first_multi_step=False,
is_last_step=True)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
torch.hpu.synchronize()
if profiler:
profiler.step()
if profiler:
profiler.stop()
self.profiler.end()
gc.collect()
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
def log_warmup(self, phase, i, max_i, batch_size, seq_len):
free_mem = format_bytes(
HabanaMemoryProfiler.current_free_device_memory())
dim = "num_blocks"
if phase == "Prompt":
dim = "seq_len"
msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} "
f"{dim}:{seq_len} "
f"free_mem:{free_mem}")
logger.info(msg)
def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
self.log_warmup('Prompt' if is_prompt else 'Decode', i,
len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
def warmup_graphs(self,
strategy,
buckets,
is_prompt,
kv_caches,
available_mem,
starting_mem=0,
total_batch_seq=0.001):
total_mem = starting_mem
idx = 0
phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
num_candidates = len(buckets)
ordering : Union[Callable[[Any], Tuple[Any, Any]], \
Callable[[Any], Tuple[Any, Any, Any]]]
if strategy == 'min_tokens':
ordering = lambda b: (b[0] * b[1], b[1], b[0])
elif strategy == 'max_bs':
ordering = lambda b: (-b[0], b[1])
else:
raise NotImplementedError(
f'Unsupported graph allocation strategy: {strategy}')
buckets = list(sorted(buckets, key=ordering))
captured_all = True
for idx, (batch_size, seq_len) in enumerate(buckets):
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len if is_prompt else batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
if mem_estimate >= available_mem:
captured_all = False
continue
graphed_bucket = (batch_size, seq_len, is_prompt)
if graphed_bucket in self.graphed_buckets:
continue
self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
used_mem = align_workers(mem_prof.consumed_device_memory,
torch.distributed.ReduceOp.MAX)
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
return total_mem, total_batch_seq, captured_all
def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
num_candidates = len(buckets)
phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
graphed = list(c[:2] for c in self.graphed_buckets
if c[2] == is_prompt)
if num_candidates == 0:
num_candidates = 1
msg = (f'{phase} captured:{len(graphed)} '
f'({100 * len(graphed) / num_candidates:.1f}%) '
f'used_mem:{format_bytes(total_mem)} '
f'buckets:{sorted(list(graphed))}')
logger.info(msg)
@torch.inference_mode()
def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
max_blocks = kv_caches[0][0].size(0)
self.bucketing_ctx.generate_decode_buckets(max_blocks)
if profile := os.environ.get('VLLM_PT_PROFILE', None):
phase, bs, seq_len, graph = profile.split('_')
is_prompt = phase == 'prompt'
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
True)
raise AssertionError("Finished profiling")
if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
cache_size_limit = 1 + 3 * (
len(self.bucketing_ctx.prompt_buckets) +
len(self.bucketing_ctx.decode_buckets))
torch._dynamo.config.cache_size_limit = max(
cache_size_limit, torch._dynamo.config.cache_size_limit)
# Multiply by 8 to follow the original default ratio between
# the cache_size_limit and accumulated_cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = max(
cache_size_limit * 8,
torch._dynamo.config.accumulated_cache_size_limit)
if self.skip_warmup:
logger.info("Skipping warmup...")
return
self.profiler.start('internal', 'warmup')
start_mem = HabanaMemoryProfiler.current_device_memory_usage()
start_time = time.perf_counter()
compile_only_mode_context = functools.partial(bc.env_setting,
"PT_COMPILE_ONLY_MODE",
True)
can_use_compile_only_mode = True
try:
with compile_only_mode_context():
pass
logger.debug("Using PT_COMPILE_ONLY_MODE.")
except KeyError:
can_use_compile_only_mode = False
logger.warning('Cannot use PT_COMPILE_ONLY_MODE. '
'Warmup time will be negatively impacted. '
'Please update Gaudi Software Suite.')
with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
kv_caches)
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
kv_caches)
if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_margin
graph_free_mem = align_workers(graph_free_mem,
torch.distributed.ReduceOp.MIN)
prompt_graph_mem_ratio = float(
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
prompt_available_memory = (prompt_graph_mem_ratio *
graph_free_mem)
decode_available_memory = (graph_free_mem -
prompt_available_memory)
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
logger.info(msg)
prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
'min_tokens')
decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
'max_bs')
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches, decode_available_memory)
# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
# left. Let's try to use it for capturing more prompt buckets.
if (mem_post_decode + mem_post_prompt < graph_free_mem
and not prompt_captured_all and decode_captured_all):
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))
# Not all decode buckets were captured, but all prompt buckets
# were captured and we have some free graph-allocated space
# left. Let's try to use it for capturing more decode buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)
self.log_graph_warmup_summary(
self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
self.log_graph_warmup_summary(
self.bucketing_ctx.decode_buckets, False, mem_post_decode)
end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()
elapsed_time = end_time - start_time
msg = (
f"Warmup finished in {elapsed_time:.0f} secs, "
f"allocated {format_bytes(end_mem - start_mem)} of device memory")
logger.info(msg)
self.profiler.end()
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
@property
def mem_margin(self) -> Optional[int]:
return self._mem_margin
@mem_margin.setter
def mem_margin(self, value):
self._mem_margin = value
def _maybe_wrap_in_hpu_graph(*args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)
class HabanaProfilerCounterHelper:
def __init__(self):
self.niter = 0
self.average_real_throughput = None
self.logged_once = False
self.real_seq_lens = []
self.prompt_seq_lens = []
def capture_seq_group_metadata_stats(self, seq_group_metadata_list):
self.real_seq_lens = [
len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids)
for seq_group_metadata in seq_group_metadata_list
for seq_data in seq_group_metadata.seq_data.values()
]
self.prompt_seq_lens = [
len(seq_data.prompt_token_ids)
for seq_group_metadata in seq_group_metadata_list
for seq_data in seq_group_metadata.seq_data.values()
]
def get_counter_dict(self, cache_config, duration, seq_len,
batch_size_padded, real_batch_size, is_prompt):
throughput = batch_size_padded / (duration / 1e6)
throughput_effective = real_batch_size / (duration / 1e6)
real_max_seq_len = max(self.real_seq_lens)
real_num_tokens = sum(self.real_seq_lens)
padded_num_tokens = batch_size_padded * seq_len
batch_token_utilization = real_num_tokens / padded_num_tokens
if self.average_real_throughput is None:
self.average_real_throughput = throughput_effective
else: # https://www.heikohoffmann.de/htmlthesis/node134.html
self.average_real_throughput = self.average_real_throughput + 1 / (
self.niter + 1) * (throughput_effective -
self.average_real_throughput)
phase = "prompt" if is_prompt else "decode"
counters = {
f'{phase}_bucket_batch_size': batch_size_padded,
f'{phase}_batch_size': real_batch_size,
f'{phase}_bucket_seq_len': seq_len,
f'{phase}_seq_len': real_max_seq_len,
f'{phase}_bucket_gen_throughput': throughput,
f'{phase}_real_gen_throughput': throughput_effective,
f'{phase}_batch_token_utilization': batch_token_utilization,
'average_real_throughput': self.average_real_throughput,
'engine_iteration': self.niter,
}
self.niter += 1
if is_prompt:
prompt_bucket_in_throughput = (seq_len * batch_size_padded) / (
duration / 1e6)
prompt_real_in_throughput = sum(
self.prompt_seq_lens) / (duration / 1e6)
counters[
f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput
counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput
# KV cache might not be created yet (e.g. for profiling run)
if cache_config.num_gpu_blocks is not None and \
cache_config.num_gpu_blocks != 0:
cache_num_blocks_used = [
math.ceil(sl / cache_config.block_size)
for sl in self.real_seq_lens
]
cache_total_num_blocks_used = sum(cache_num_blocks_used)
num_cache_blocks = cache_config.num_gpu_blocks
cache_total_num_free_blocks = \
num_cache_blocks - cache_total_num_blocks_used
cache_computed_utilization = \
cache_total_num_blocks_used / num_cache_blocks
max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size)
batch_block_utilization = cache_total_num_blocks_used / (
batch_size_padded * max_blocks_per_seq)
counters['cache_num_blocks_used'] = cache_total_num_blocks_used
counters['cache_num_free_blocks'] = cache_total_num_free_blocks
counters['cache_computed_utilization'] = cache_computed_utilization
counters[
f'{phase}_batch_block_utilization'] = batch_block_utilization
if not self.logged_once:
counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks
counters[
'const_gpu_memory_utilization'] = \
cache_config.gpu_memory_utilization
counters['const_block_size'] = cache_config.block_size
self.logged_once = True
return counters
def unwrap_model(model):
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
return unwrap_model(model._orig_mod)
else:
model = list(vars(model)['_modules'].values())[0]
modules = list(vars(model)['_modules'].values())
return modules
class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
"""
GPU model runner with sampling step.
"""
_model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = (
ModelInputForHPUWithSamplingMetadata)
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForHPUWithSamplingMetadata:
return (
ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
@torch.inference_mode()
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForHPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
If cuda graph is required, this API automatically pads inputs.
"""
with self.profiler.record_event('internal', 'prepare_input_tensors'):
assert seq_group_metadata_list is not None
if self.profiler.enabled:
self.profiler_counter_helper.capture_seq_group_metadata_stats(
seq_group_metadata_list=seq_group_metadata_list)
model_input, sampling_metadata = self.prepare_input_tensors(
seq_group_metadata_list)
assert model_input.attn_metadata is not None
is_prompt = model_input.attn_metadata.is_prompt
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt,
virtual_engine=virtual_engine)
def finish_measurements(self):
from neural_compressor.torch.quantization import finalize_calibration
finalize_calibration(self.model.model)
def _num_blocks(self, attn_metadata):
if attn_metadata.block_list is None:
return 0
return attn_metadata.block_list.numel()
def _phase(self, attn_metadata):
phase_type: PhaseType
is_prompt = attn_metadata.is_prompt
is_prefix_prefill = is_prompt and attn_metadata.block_list is not None
if is_prompt and is_prefix_prefill:
phase_type = PhaseType.PREFIX_PREFILL
elif is_prompt and not is_prefix_prefill:
phase_type = PhaseType.PREFILL
elif not is_prompt:
phase_type = PhaseType.DECODE
else:
raise ValueError("Unrecognized pass type, likely due to malformed "
"attention metadata")
return phase_type
def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
cfg: Optional[tuple] = None
assert cfg is None, "Configs changed between 2D and 3D"
if is_prefix_caching:
phase = self._phase(attn_metadata)
num_blocks = self._num_blocks(attn_metadata)
cfg = (batch_size, seq_len, num_blocks, phase)
else:
phase = 'prompt' if attn_metadata.is_prompt else 'decode'
cfg = (batch_size, seq_len, phase)
seen = cfg in self.seen_configs
self.seen_configs.add(cfg)
if not seen and not warmup_mode:
logger.warning("Configuration: %s was not warmed-up!",
(phase.value, batch_size, seq_len,
num_blocks) if is_prefix_caching else
(phase, batch_size, seq_len))
def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
is_prompt: bool):
'''
This is a helper function to create the mask for lora computations.
Lora Mask is needed to ensure we match the correct lora weights for the
for the request.
For Prompt phase we have
lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
lora_logits_mask with shape (batch_size, max_loras * max_rank)
For Decode phase we have both
lora_mask and lora_logits_mask with shape
(batch_size, max_loras * max_rank)
'''
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
lora_index = 0
if self.lora_config:
if is_prompt:
lora_mask = torch.zeros(
input_tokens.shape[0] * input_tokens.shape[1],
(self.lora_config.max_loras) *\
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
lora_logits_mask = torch.zeros(
input_tokens.shape[0], (self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
ones = torch.ones(input_tokens.shape[1],
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
logit_ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
for i in range(len(lora_ids)):
if lora_ids[i] == 0:
continue
lora_index = self.lora_manager._adapter_manager.\
lora_index_to_id.index(lora_ids[i])
start_row = i * input_tokens.shape[1]
end_row = start_row + input_tokens.shape[1]
start_col = lora_index * self.lora_config.max_lora_rank
end_col = start_col + self.lora_config.max_lora_rank
lora_mask[start_row:end_row, start_col:end_col] = ones
lora_logits_mask[i, start_col:end_col] = logit_ones
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_logits_mask.to('hpu')
else:
lora_mask = torch.zeros(input_tokens.shape[0],
(self.lora_config.max_loras) *
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
ones = torch.ones(1,
self.lora_config.max_lora_rank,
dtype=self.lora_config.lora_dtype)
for i in range(len(lora_ids)):
if lora_ids[i] == 0:
continue
lora_index = self.lora_manager._adapter_manager.\
lora_index_to_id.index(lora_ids[i])
start_pos = lora_index * self.lora_config.max_lora_rank
end_pos = start_pos + self.lora_config.max_lora_rank
lora_mask[i, start_pos:end_pos] = ones
lora_mask = lora_mask.to('hpu')
lora_logits_mask = lora_mask
return lora_mask, lora_logits_mask
def _get_seq_ids(self, model_input):
return ([
sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
])
def _pad_to_max_num_seqs(self, tensor, value):
padding_needed = self.max_num_seqs - tensor.size(0)
if padding_needed:
padding = torch.full((padding_needed, *tensor.shape[1:]),
value,
device=tensor.device,
dtype=tensor.dtype)
tensor = torch.cat([tensor, padding])
return tensor
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForHPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
warmup_mode=False,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
assert not (use_delayed_sampling and num_steps != 1), \
'Delayed sampling is not compatible with MSS!'
assert model_input.input_tokens is not None
if use_delayed_sampling and not model_input.is_prompt and \
self.is_driver_worker:
num_cached = len(self.cached_step_outputs)
assert num_cached > 0
cur_seq_ids = self._get_seq_ids(model_input)
cur_seq_id_pos = {
sid: idx
for idx, sid in enumerate(cur_seq_ids) if sid >= 0
}
htorch.core.mark_step()
for i in range(num_cached):
prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
target_indices = [
cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
]
padding = self.cached_step_outputs[i].size(0) - len(
target_indices)
target_indices.extend([-1] * padding)
target_indices = torch.tensor(
target_indices,
device=model_input.input_tokens.device,
dtype=model_input.input_tokens.dtype)
model_input.input_tokens.index_copy_(
0, target_indices, self.cached_step_outputs[i])
htorch.core.mark_step()
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
return []
# last multi-step
output = self._decode_sampler_outputs(
model_input) if self.is_driver_worker else []
torch.hpu.synchronize()
if model_input.is_first_multi_step:
# first multi-step
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
# Rank!=0 workers has is_prompt==None
if use_delayed_sampling and not model_input.is_prompt and \
model_input.input_tokens.size(1) == 1:
if self.is_driver_worker:
model_kwargs_broadcast_data = {
"input_tokens": model_input.input_tokens
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
input_tokens = model_input.input_tokens
else:
model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
input_tokens = model_kwargs_broadcast_data["input_tokens"]
else:
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
real_batch_size = model_input.real_batch_size
batch_size_padded = model_input.batch_size_padded
assert input_tokens is not None
assert input_positions is not None
assert sampling_metadata is not None
assert attn_metadata is not None
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids,
attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update(
{"bypass_hpu_graphs": not use_graphs})
htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = ("model_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
if num_steps > 1 or use_delayed_sampling:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
self.model.sampler.include_gpu_probs_tensor = True
cache_orig_output_tokens_len: List[Dict] = []
def try_revert_dummy_output_tokens():
if len(cache_orig_output_tokens_len) > 0:
# Reuse the original output token ids length
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
for j, data in seq_group_metadata.seq_data.items():
orig_output_tokens_len = \
cache_orig_output_tokens_len[i][j]
data.output_token_ids = \
data.output_token_ids[:orig_output_tokens_len]
for i in range(num_steps):
if i != 0 and not self.is_driver_worker:
broadcast_data = broadcast_tensor_dict(src=0)
if 'early_exit' in broadcast_data and broadcast_data[
'early_exit']:
return [output] if num_steps == 1 else []
execute_model_kwargs.update({
"input_ids":
broadcast_data["input_ids"],
"positions":
broadcast_data["positions"],
"attn_metadata":
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.
selected_token_indices)
if self.lora_config:
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
# Compute the logits.
with self.profiler.record_event(
'internal',
('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
continue
if use_delayed_sampling:
fake_output = self._delayed_sampler_outputs(model_input)
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
if num_steps > 1:
output = output.sampled_token_ids
self.cached_step_outputs.append(output)
if use_delayed_sampling and self.is_driver_worker:
self._patch_prev_output()
output = self._pad_to_max_num_seqs(
output.sampled_token_ids, DUMMY_TOKEN_ID)
self.cached_step_outputs.append(output)
self.cached_step_inputs.append(model_input)
htorch.core.mark_step()
if model_input.async_callback is not None:
model_input.async_callback()
if i < num_steps - 1:
if i == 0:
if model_input.async_callback is not None:
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
seq_group_metadata_list = \
ctx.seq_group_metadata_list
elif seqs is not None:
seq_group_metadata_list = seqs
else:
raise RuntimeError(
"seq_group_metadata_list is uninitialized")
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
# Skip empty steps
seq_group_metadata.state.current_step += (
num_steps - 2)
# Cache the original output token ids
cache_orig_output_tokens_len.append({})
for j, data in seq_group_metadata.seq_data.items():
cache_orig_output_tokens_len[i][j] = \
len(data.output_token_ids)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
max_output_len = sampling_metadata.seq_groups[
0].sampling_params.max_tokens
if len(data.output_token_ids) < max_output_len - 1:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token = (540, )
data.output_token_ids += (dummy_token)
else:
broadcast_tensor_dict({'early_exit': True},
src=0)
if num_steps == 1:
return [output]
else:
try_revert_dummy_output_tokens()
return []
result = self._prepare_decode(seq_group_metadata_list,
output=output)
execute_model_kwargs.update({
"input_ids":
result.input_tokens,
"positions":
result.input_positions,
"attn_metadata":
self.trim_attn_metadata(result.attn_metadata)
})
model_kwargs_broadcast_data = {
"input_ids": result.input_tokens,
"positions": result.input_positions,
"attn_metadata": vars(result.attn_metadata)
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
else:
try_revert_dummy_output_tokens()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
if use_delayed_sampling:
if self.is_driver_worker:
return [fake_output]
else:
return []
return [output] if self.is_driver_worker else []
else:
return []
return output if type(output) is list else [output]
def _delayed_sampler_outputs(self, model_input):
next_token_ids = [[DUMMY_TOKEN_ID]] * len(
model_input.sampling_metadata.seq_groups)
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
return sampler_output
def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=False)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
return sampler_outputs
def _make_decode_output(
self,
next_token_ids: List[List[int]],
seq_groups: List[SequenceGroupToSample],
) -> SamplerOutput:
zero_logprob = Logprob(0.0)
sampler_outputs = []
batch_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
seq_outputs = []
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx][0]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return SamplerOutput(sampler_outputs)
def shutdown_inc(self):
can_finalize_inc = False
from contextlib import suppress
with suppress(AttributeError):
can_finalize_inc = (self.model_config.quantization == 'inc') and \
(self.model.model is not None) and \
self.inc_initialized_successfully and \
not getattr(self, "_is_inc_finalized", False)
if can_finalize_inc:
from neural_compressor.torch.quantization import (
finalize_calibration)
finalize_calibration(self.model.model)
self._is_inc_finalized = True
def __del__(self):
self.shutdown_inc()
def _patch_prev_output(self):
assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
f'''Inputs and outputs are out of sync!
{len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
if len(self.cached_step_inputs) == 0:
return
model_input = self.cached_step_inputs.pop(0)
delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
-1).tolist()
ctx = model_input.async_callback.keywords["ctx"] # type: ignore
# If there's no output to patch with, which is usually the case when
# we're starting a new request after all requests are completed.
if len(ctx.output_queue) == 0:
return
assert len(
ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
output_data = ctx.output_queue[0]
assert len(output_data.outputs) == 1
for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
fake_out.samples[0].output_token = real_out
for sg, real_out in zip(output_data.seq_group_metadata_list,
delayed_output):
assert len(sg.seq_data) == 1
seq_data = list(sg.seq_data.values())[0]
# This is a hack. Assigning output_token_ids triggers
# a cache recomputation and we only need to update the last token
seq_data.output_token_ids_array[-1] = real_out
seq_data._cached_all_token_ids[-1] = real_out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import contextlib
import gc
import os
from typing import List, Optional, Set, Tuple, Type
import habana_frameworks.torch as htorch # noqa:F401
import torch
import torch.distributed
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class HPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a HPU.
Each worker is associated with a single HPU. The worker is responsible for
maintaining the KV cache and executing the model on the HPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner: HPUModelRunner = HPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[HPUCacheEngine]
# Initialize gpu_cache as pooling models don't initialize kv_caches
self.hpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.HPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def _set_env_vars(self):
local_rank = self.local_rank
if self.parallel_config.world_size == 1:
local_rank = -1
import os
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["ID"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size)
os.environ["RANK"] = str(self.rank)
def init_device(self) -> None:
if self.device_config.device.type == "hpu":
self.device = torch.device("hpu")
torch.hpu.set_device(self.device)
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
if self.model_config.quantization == 'inc':
self._set_env_vars()
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if (log_graph_compilation or log_cpu_fallbacks) and \
execute_model_req is not None:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
]) # whoa, that's some spicy stuff right here
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1]
> 0) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1]
> 0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
return output
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
return output
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with HabanaMemoryProfiler() as m:
self.model_runner.profile_run()
torch.hpu.synchronize()
msg = ("Model profiling run "
f"took {m.get_summary_string()}")
logger.info(msg)
# At this point we should've allocated the maximum workspace for all
# recipes we will use the extra memory for graphs/blocks
free_hpu_memory = torch.hpu.mem_get_info()[0]
cache_block_size = self.get_cache_block_size_bytes()
graph_reserved_mem = (float(
os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1'))
if not self.model_config.enforce_eager else 0)
graph_headroom = 1 - graph_reserved_mem
available_hpu_memory = free_hpu_memory * \
self.cache_config.gpu_memory_utilization
hpu_memory_margin = free_hpu_memory * (
1 - self.cache_config.gpu_memory_utilization)
self.model_runner.mem_margin = hpu_memory_margin
cache_size_bytes = available_hpu_memory * graph_headroom
graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
msg = (
f"Free device memory: {format_bytes(free_hpu_memory)}, "
f"{format_bytes(available_hpu_memory)} usable "
f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
f"{format_bytes(cache_size_bytes)} reserved for KV cache")
logger.info(msg)
num_hpu_blocks = int(cache_size_bytes // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
return num_hpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(
num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
with HabanaMemoryProfiler() as m:
self._init_cache_engine()
torch.hpu.synchronize()
msg = ("Initializing cache engine "
f"took {m.get_summary_string()}")
logger.info(msg)
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
HPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.hpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.hpu_cache)
def _warm_up_model(self) -> None:
# NOTE(kzawora): We should use virtual engine index here
# for pipeline parallelism. Using 0 for now.
assert self.hpu_cache is not None
self.model_runner.warmup_model(self.hpu_cache[0])
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def finish_measurements(self):
self.model_runner.finish_measurements()
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.hpu_cache
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
device="cpu",
dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
# `blocks_to_copy` is a gpu tensor. The src and tgt of
# blocks to copy are in the same device, and `blocks_to_copy`
# can be used directly within cuda kernels.
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def shutdown_inc(self):
self.model_runner.shutdown_inc()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return HPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
init_distributed_environment(parallel_config.world_size,
rank,
distributed_init_method,
local_rank,
backend=current_platform.dist_backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="hccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup & checking conformance.
dummy_tensor_hpu = torch.ones(1).to('hpu')
torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
pipeline_parallel_size) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
class HPUCacheEngine(CacheEngine):
def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(self.num_attention_layers):
key_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
value_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
kv_layer = (key_cache, value_cache)
kv_cache.append(kv_layer)
return kv_cache
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.hpu_model_runner import ModelInputForHPU
from vllm.worker.hpu_worker import HPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepHPUWorker(HPUWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cached_model_input: Optional[ModelInputForHPU] = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
"""
Get the driver input and broadcast it to other workers.
"""
assert self.is_driver_worker
assert execute_model_req.virtual_engine == 0
is_first_multi_step = execute_model_req.is_first_multi_step
is_last_step = execute_model_req.is_last_step
if is_first_multi_step:
# on first step we prepare the worker input and model input normally
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
worker_input = dataclasses.replace(
worker_input,
num_steps=execute_model_req.num_lookahead_slots + 1)
model_input: ModelInputForHPU = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input = dataclasses.replace(
model_input,
async_callback=execute_model_req.async_callback)
else:
# on subsequent steps we reuse the worker input and model input
assert self.cached_model_input is not None
model_input = self.cached_model_input
worker_input = WorkerInput()
model_input = dataclasses.replace(
model_input,
is_first_multi_step=is_first_multi_step,
is_last_step=is_last_step)
if self.do_metadata_broadcast:
if is_first_multi_step:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
broadcast_data = {
"is_first_multi_step": is_first_multi_step,
"is_last_step": is_last_step,
}
broadcast_tensor_dict(broadcast_data, src=0)
# Returning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
torch.Tensor]]]:
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
execute_model_req)
if model_input.is_first_multi_step:
self.cached_model_input = model_input
return model_input, worker_input, {}
else:
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
if len(broadcast_data) == 2:
assert self.cached_model_input is not None
self.cached_model_input = dataclasses.replace(
self.cached_model_input,
is_first_multi_step=broadcast_data["is_first_multi_step"],
is_last_step=broadcast_data["is_last_step"])
empty_worker_input = WorkerInput()
return self.cached_model_input, empty_worker_input, {}
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.cached_model_input = model_input
return model_input, worker_input, {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment