Commit 3de379de authored by zhuwenwen's avatar zhuwenwen
Browse files

update unused code

parent 5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import weakref
from typing import List, Optional, Set, Tuple, Dict
import torch
import torch.nn.functional as F
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import DelegateWorkerBase
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.distributed import broadcast_tensor_dict
from vllm.worker.worker_base import WorkerWrapperBase
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
# skip lora config in medusa
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: SpeculativeProposer
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self):
self.worker.init_device()
def load_model(self):
super().load_model()
# get medusa choices and generate medusa_buffers
self.medusa_buffers = None
if self.tree_decoding and hasattr(self.model_runner.model, 'medusa_choices'):
self.medusa_choices = self.model_runner.model.medusa_choices
if self.medusa_choices is not None:
self.medusa_buffers = self.generate_medusa_buffers(
self.medusa_choices, device=self.device
)
if self.medusa_buffers is None:
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
else:
self._proposer = TreeStyleProposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
self.medusa_buffers,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
pass
def set_should_modify_greedy_probs_inplace(self):
pass
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Dict[str, torch.Tensor]:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
sample_indices_list = []
for seq_group in sampling_metadata.seq_groups:
sample_indices_list.append(seq_group.sample_indices)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
previous_logits = execute_model_req.previous_logits.logits if \
execute_model_req.previous_logits is not None else None
tensor_dict = {
"previous_hidden_states": previous_hidden_states,
"previous_logits": previous_logits,
"sample_indices_list": sample_indices_list,
"seq_lens": seq_lens
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
# Unused parameter.
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self._raise_if_unsupported(execute_model_req)
if self.is_driver_worker:
tensor_dict = self._get_driver_input_and_broadcast(execute_model_req)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of medusa worker!!!")
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=tensor_dict["previous_hidden_states"],
sample_indices_list=tensor_dict["sample_indices_list"],
previous_logits=tensor_dict["previous_logits"],
medusa_buffers=self.medusa_buffers)
# create tree attn masks
if self.is_driver_worker and self.medusa_buffers is not None:
seq_lens = tensor_dict["seq_lens"]
max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens):
context_len = seq_len
attn_masks = self.medusa_buffers['tree_attn_masks']
left_mask = torch.ones(attn_masks.shape[0], context_len,
dtype=attn_masks.dtype,
device=attn_masks.device)
attn_masks = torch.cat([left_mask, attn_masks], dim=-1)
right_pad = max_context_len - context_len
if right_pad > 0:
attn_masks = F.pad(attn_masks, (0, right_pad), "constant", 0)
sampler_output.tree_attn_masks = attn_masks
return model_outputs, False
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[List[int], List[int]]:
if not seq_group_metadata_list:
return [], []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
seq_lens.append(seq_len)
query_lens.append(seq_len - context_len)
else:
# first step of tree decoding need to ignore first token
if self.medusa_buffers is not None and seq_data.get_first_step_flag():
seq_data_len -= 1
seq_lens.append(seq_data_len)
query_lens.append(1)
return seq_lens, query_lens
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if execute_model_req is None:
return None
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MedusaWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MedusaWorker does not support beam search.")
def pad_path(self, path, length, pad_value=-2):
"""
Pad the given path list with a specific value up to a specified length.
Parameters:
- path (list): The original list that needs padding.
- length (int): The desired length of the padded list.
- pad_value (optional, default=-2): The value to use for padding.
Returns:
- list: A new list based on the original path but padded to the desired length.
Example:
>>> pad_path([1,2,3], 5)
[1, 2, 3, -2, -2]
Note:
If the given path is already longer than the specified length,
then no padding occurs, and the original path is returned.
"""
# Calculate the number of padding values needed by subtracting the length
# of the path from the desired length.
# Append the padding values to the original path and return the new list.
return path + [pad_value] * (length - len(path))
def generate_medusa_buffers(self, medusa_choices, device="cuda"):
"""
Generate buffers for the Medusa structure based on the provided choices.
Parameters:
- medusa_choices (list): A nested list representing tree in the Medusa structure.
- device (str): Device to which the tensors should be moved. Default is "cuda".
Returns:
- dict: A dictionary containing buffers related to the Medusa structure.
"""
# Sort the medusa_choices based on their lengths and then their values
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
medusa_len = len(sorted_medusa_choices) + 1
# Initialize depth_counts to keep track of how many choices have a particular depth
depth_counts = []
prev_depth = 0
for path in sorted_medusa_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
# Create the attention mask for Medusa
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
medusa_attn_mask[:, 0] = 1
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
# retrieve ancestor position
if len(cur_medusa_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_medusa_choice) - 1):
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
start += depth_counts[i]
# Generate tree indices for the Medusa structure
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
medusa_tree_indices[0] = 0
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
start += depth_counts[i]
# Generate position IDs for the Medusa structure
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
start = 0
for i in range(len(depth_counts)):
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
start += depth_counts[i]
# Generate retrieval indices for Medusa structure verification
retrieve_indices_nest = []
retrieve_paths = []
for i in range(len(sorted_medusa_choices)):
cur_medusa_choice = sorted_medusa_choices[-i-1]
retrieve_indice = []
if cur_medusa_choice in retrieve_paths:
continue
else:
for c in range(len(cur_medusa_choice)):
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
retrieve_paths.append(cur_medusa_choice[:c+1])
retrieve_indices_nest.append(retrieve_indice)
max_length = max([len(x) for x in retrieve_indices_nest])
retrieve_indices = [self.pad_path(path, max_length) for path in retrieve_indices_nest]
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
retrieve_indices = retrieve_indices + 1
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
# Aggregate the generated buffers into a dictionary
medusa_buffers = {
"tree_attn_masks": medusa_attn_mask.int(),
"tree_indices": medusa_tree_indices,
"tree_position_ids": medusa_position_ids,
"retrieve_indices": retrieve_indices,
}
# Move the tensors in the dictionary to the specified device
medusa_buffers = {
k: v.clone().to(device)
if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device)
for k, v in medusa_buffers.items()
}
return medusa_buffers
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Callable, Optional, Union
import msgspec
import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
class SpecDecodeWorkerMetrics(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""
# The empirical acceptance rate of the proposal method on a per-token basis.
# This is useful for evaluating how well the proposal method aligns with the
# scoring method.
draft_acceptance_rate: float
# The empirical efficiency, measured as the number of tokens emitted by the
# system divided by the number of tokens that could be emitted by the system
# if the proposal method were perfect.
system_efficiency: float
# The number of speculative tokens produced by the proposal method.
draft_tokens: int
# The number of tokens emitted by the entire system.
emitted_tokens: int
# The number of tokens accepted by the scoring model and verification
# routine, e.g. Llama2-70B and lossless rejection sampling.
#
# NOTE: Any token accepted by the verification routine is considered
# accepted (regardless of if the speculative prefix is also accepted). The
# user will usually see less accepted tokens. This metric is helpful when
# evaluating alignment of the proposal method with the scoring model.
accepted_tokens: int
# The number of speculative tokens per sequence.
num_spec_tokens: int
Timer = Callable[[], float]
class AsyncMetricsCollector:
"""Class which copies rejection/typical-acceptance sampler metrics
from the device to CPU on a non-default Torch stream.
"""
def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler
self._timer = time.time if timer is None else timer
self._rank: Optional[int] = None
# We don't have a device set yet.
self._copy_stream: Optional[torch.cuda.Stream] = None
self._in_flight_copy: Optional[torch.cuda.Event] = None
self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s
self._last_metrics_collect_time = self._timer()
def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()
def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type
# stream = current_platform.Stream
# if stream is not None:
# self._copy_stream = stream()
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# Skip for any platform that doesn't have device Event
# if current_platform.Event is None:
# return None
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
ready_event = self._in_flight_copy
self._in_flight_copy = None
return self._collect_rejsample_metrics(k, ready_event)
# Otherwise, check if we should start a new copy.
if self._should_collect_rejsample_metrics(self._timer()):
assert self._in_flight_copy is None
self._in_flight_copy = self._copy_rejsample_metrics_async()
return None
def _should_collect_rejsample_metrics(self, now: float) -> bool:
"""Return whether or not this iteration should print sampling
metrics.
"""
if self._rank != 0:
return False
return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a device event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(current_platform.current_stream())
with current_platform.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = current_platform.Event()
aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready
def _collect_rejsample_metrics(
self, k: int,
ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
"""Create metrics object from statistics copied asynchronously.
Args:
k: int. The number of speculative tokens; used to determine system
efficiency.
ready_event: torch.cuda.Event. The CUDA event recording when the
async GPU->CPU copy is complete.
"""
ready_event.synchronize()
# update time of last collection
self._last_metrics_collect_time = self._timer()
accepted_tokens = self._aggregate_num_accepted_tokens.item()
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
max_num_emitted_tokens = self.get_max_num_emitted_tokens(
draft_tokens, k)
if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens
else:
draft_acceptance_rate = float("nan")
if max_num_emitted_tokens > 0:
system_efficiency = emitted_tokens / max_num_emitted_tokens
else:
system_efficiency = float("nan")
return SpecDecodeWorkerMetrics(
num_spec_tokens=k,
draft_acceptance_rate=draft_acceptance_rate,
system_efficiency=system_efficiency,
accepted_tokens=accepted_tokens,
draft_tokens=draft_tokens,
emitted_tokens=emitted_tokens,
)
@staticmethod
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert draft_tokens % k == 0
total_num_spec_seqs = draft_tokens // k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted = k + 1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Set, Tuple, Dict
import torch
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.distributed import broadcast_tensor_dict
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
"""Worker for MLPSpeculator models.
Not currently compatible with LoRA or chunked prefill.
"""
def _get_driver_input_and_broadcast(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
index: int,
last_tokens: Optional[torch.Tensor]=None,
previous_hidden_states: Optional[torch.Tensor]=None,
sampling_metadata: Optional[SamplingMetadata]=None
) -> Dict[str, torch.Tensor]:
if sampling_metadata is None and execute_model_req is not None:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
# b x 1
last_tokens = input_tokens.unsqueeze(1)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
tensor_dict = {
"input_tokens": last_tokens,
"previous_hidden_states": previous_hidden_states,
"sample_len": sample_len,
"head_index": index
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict, sampling_metadata
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
# therefore does not need this parameter.
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For mlp spec worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
model_outputs = []
last_tokens = None
previous_hidden_states = None
sampling_metadata = None
for index in range(sample_len):
if self.is_driver_worker:
tensor_dict, sampling_metadata = self._get_driver_input_and_broadcast(execute_model_req,
sample_len,
index,
last_tokens,
previous_hidden_states,
sampling_metadata)
assert sampling_metadata is not None
output, previous_hidden_states = self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=sampling_metadata,
head_index=index)
last_tokens = output.sampled_token_ids
model_outputs.append(output)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of mlp_speculator worker!!!")
self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=None,
head_index=tensor_dict["head_index"])
if self.is_driver_worker:
assert len(model_outputs) == sample_len
return model_outputs, True
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, List[int], List[int]]:
if not seq_group_metadata_list:
return torch.empty(0, device=self.device), [], []
input_tokens: List[int] = []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
input_tokens.extend(tokens)
query_lens.append(seq_len - context_len)
else:
seq_lens.append(seq_data_len)
input_tokens.append(seq_data.get_last_token_id())
query_lens.append(1)
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
return input_tokens_tensor, seq_lens, query_lens
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import DelegateWorkerBase
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations, or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
def __init__(self, *args, **kwargs):
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: SpeculativeProposer
def init_device(self) -> None:
self.worker.init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.sampler.include_gpu_probs_tensor = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.include_gpu_probs_tensor) = True
def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.sampler.should_modify_greedy_probs_inplace = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.should_modify_greedy_probs_inplace
) = True
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
if expanded_request.previous_hidden_states is not None:
self.worker.model_runner.return_hidden_states = True
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._maybe_update_previous_hidden_states(
model_output, expanded_request)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
@staticmethod
def _maybe_update_previous_hidden_states(
model_output: SamplerOutput,
expanded_request: ExecuteModelRequest) -> None:
"""
Updates the previous hidden states in an expanded request
in-place with the hidden states from the model output.
"""
if expanded_request.previous_hidden_states is not None:
expanded_request.previous_hidden_states = HiddenStates(
model_output.hidden_states,
expanded_request.seq_group_metadata_list)
@staticmethod
def _expand_execute_model_request(
execute_model_req: ExecuteModelRequest,
seq_with_bonus_token_in_last_step: set,
) -> Tuple[ExecuteModelRequest, List[int]]:
"""
Expands the execute model request based on sequences with bonus
tokens.
For each sequence with a bonus token, this method creates a new
sequence without the bonus token and adds it to the execute model
request. The original sequence groups are also retained. The indices
of the original sequence groups are returned for further processing.
Args:
execute_model_req (ExecuteModelRequest): The original execute
model request.
seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
contain bonus tokens.
Returns:
Tuple[ExecuteModelRequest, List[int]]: The updated execute model
request with expanded sequences and a list of indices corresponding
to the original sequence groups.
"""
updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
updated_execute_model_req = execute_model_req.clone(
updated_seq_group_metadata_list)
indices_of_original_sequence_groups = []
for seq_group in execute_model_req.seq_group_metadata_list:
seq_group_has_bonus_tokens = False
for seq_id, _ in seq_group.seq_data.items():
# Identify sequences with bonus tokens in the sequence group.
if seq_id in seq_with_bonus_token_in_last_step:
seq_group_has_bonus_tokens = True
break
if seq_group_has_bonus_tokens:
#Create new sequences without the last bonus token. These new
# sequence have the same sequence id as the original sequence.
# We create a new sequence group and add them there.
updated_seq_group_without_bonus_token = \
MultiStepWorker._copy_seq_metadata_excluding_last_token(
seq_group, seq_with_bonus_token_in_last_step)
updated_seq_group_metadata_list.append(
updated_seq_group_without_bonus_token)
# Add the original sequence group.
updated_seq_group_metadata_list.append(
MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
# Record the index of the original sequence group.
indices_of_original_sequence_groups.append(
len(updated_seq_group_metadata_list) - 1)
updated_execute_model_req.seq_group_metadata_list =\
updated_seq_group_metadata_list
if isinstance(updated_execute_model_req.previous_hidden_states,
HiddenStates):
updated_execute_model_req.previous_hidden_states\
.expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
return updated_execute_model_req, indices_of_original_sequence_groups
@staticmethod
def _filter_model_output(
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[output_indices_to_retain]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[output_indices_to_retain]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[output_indices_to_retain]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: set,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
@staticmethod
def _append_new_tokens(
model_output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
indices_of_seq_with_bonus_tokens: List[int]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
count = 0
for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
zip(seq_group_metadata_list, model_output)):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
# Determine the actual token ID to be generated,
# considering bonus tokens
if index != indices_of_seq_with_bonus_tokens[count]:
bonus_seq_metadata = seq_group_metadata_list[
indices_of_seq_with_bonus_tokens[count]]
_, bonus_token_seq_data = next(
iter(bonus_seq_metadata.seq_data.items()))
token_id = bonus_token_seq_data.output_token_ids[-1]
else:
count += 1
seq.append_token_id(token_id, token_logprob.logprob,
seq_output.output_embed)
seq.update_num_computed_tokens(1)
@staticmethod
def _shallow_copy_seq_group_metadata(
seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
Helpful when the vLLM scheduler runs in the same process as the worker.
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
# Shallow-copy the SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
# We must shallow-copy seq_group_metadata as is_prompt could change.
new_seq_group_metadata = copy.copy(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[seq_id].output_token_ids =\
old_seq_data.output_token_ids[:]
new_seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata
@staticmethod
def _copy_seq_metadata_excluding_last_token(
seq_group_metadata: SequenceGroupMetadata,
seq_ids_to_copy: Set[int],
) -> SequenceGroupMetadata:
"""
Creates a shallow copy of the given SequenceGroupMetadata, retaining
only the sequence IDs specified in seq_ids_to_copy. For each of these
sequence IDs, all output_token_ids except the last one are copied.
Sequence IDs not in seq_ids_to_copy are excluded from the copy.
Parameters:
seq_group_metadata (SequenceGroupMetadata): The original sequence
group metadata.
seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
copy.
Returns:
SequenceGroupMetadata: A shallow copy of the sequence group metadata
with the specified modifications.
"""
# Shallow-copy the SequenceGroupMetadata.
new_seq_group_metadata = copy.copy(seq_group_metadata)
# Shallow-copy seq_data and modify the output_token_ids.
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
if (seq_id in seq_ids_to_copy):
new_seq_data[seq_id] = copy.copy(old_seq_data)
# Copy all the output token ids except the last.
# Also reduce num_computed_tokens by 1 since we are not
# including the last output token.
# NOTE: num_computed_tokens is not directly used by the
# speculative decoding workers, as it is only relevant for
# chunked prefill, which is disabled for speculative decoding.
# However, to maintain consistency in num_computed_tokens,
# we update it here.
new_seq_data[seq_id].output_token_ids =\
old_seq_data.output_token_ids[:-1]
new_seq_data[seq_id].update_num_computed_tokens(-1)
new_seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
num_steps: int) -> None:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert self.model_runner.block_size is not None
for seq_group_metadata in seq_group_metadata_list:
# Only one seq_id is guaranteed because there is no beam search.
seq_id = list(seq_group_metadata.seq_data.keys())[0]
seq = seq_group_metadata.seq_data[seq_id]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len = seq.get_len() + num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots = final_seq_len - 1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots of block.
number_physical_blocks = len(
seq_group_metadata.block_tables[seq_id])
allocated_kv_slots = (number_physical_blocks *
self.model_runner.block_size)
if required_num_kv_slots > allocated_kv_slots:
request_id = seq_group_metadata.request_id
raise ValueError(
"The worker attempted to run "
f"{num_steps} times but found insufficient KV space for "
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
f"{required_num_kv_slots=}).")
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if execute_model_req is None:
return None
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
def maybe_load_lm_head_weight(
self,
lm_head_weight: torch.Tensor,
) -> None:
weight_loader = getattr(
self.worker.model_runner.model_runner.model.lm_head.weight,
"weight_loader", default_weight_loader)
weight_loader(
self.worker.model_runner.model_runner.model.lm_head.weight,
lm_head_weight)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
get_tp_group,
tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
import vllm.envs as envs
logger = init_logger(__name__)
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
vllm_config: VllmConfig = kwargs.get("vllm_config")
speculative_config: SpeculativeConfig = vllm_config.speculative_config
assert speculative_config is not None
if vllm_config.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError("Speculative decoding is currently "
"incompatible with pipeline parallelism")
draft_worker_kwargs = kwargs.copy()
kwargs["model_runner_cls"] = TargetModelRunner
target_worker_config = copy.deepcopy(vllm_config)
target_worker_config.parallel_config.worker_cls =\
target_worker_config.parallel_config.sd_worker_cls
cls = resolve_obj_by_qualname(
target_worker_config.parallel_config.worker_cls)
target_worker = cls(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs
draft_worker_config = copy.deepcopy(vllm_config)
draft_worker_config.model_config = speculative_config.draft_model_config
# draft_worker_config.quant_config = VllmConfig._get_quantization_config(
# draft_worker_config.model_config,
# vllm_config.load_config,
# )
speculative_config.draft_parallel_config.worker_cls =\
draft_worker_config.parallel_config.sd_worker_cls
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
# TODO allow draft-model specific load config.
# Override draft-model specific worker args.
draft_worker_kwargs.update(
vllm_config=draft_worker_config,
ngram_prompt_lookup_max=speculative_config.prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.prompt_lookup_min,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.disable_mqa_scorer,
disable_by_batch_size=speculative_config.disable_by_batch_size,
draft_token_acceptance_method=speculative_config.acceptance_method,
typical_acceptance_sampler_posterior_threshold=speculative_config.
posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
num_speculative_tokens=speculative_config.num_speculative_tokens,
)
return spec_decode_worker
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
method, such as a small draft model, to speculate ahead of a larger LLM. The
probabilities of the speculative tokens are then determined by the larger
LLM, after which some verification routine determines which (if any) of the
speculative tokens are accepted by the larger LLM.
See https://github.com/vllm-project/vllm/pull/2188 and
https://github.com/vllm-project/vllm/pull/3103 for more info.
The current implementation has the following limitations:
* Only draft-model proposal is implemented (contributions for more forms are
welcome!).
* Only top-1 proposal and scoring are implemented. Tree-attention is left as
future work.
* All sequences in a batch must have the same proposal length, or zero. This
can be improved by having per-sequence speculation in the future.
* The scoring forward pass is done without an MQA kernel, which is
suboptimal especially as the batch size, proposal length, and sequence
lengths grow. Contributions to add a MQA scoring are welcome once
correctness tests pass.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
@classmethod
def create_worker(
cls,
scorer_worker: WorkerBase,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
num_speculative_tokens: int,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
num_spec_prefill_steps = 1
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0:
assert draft_parallel_config.tensor_parallel_size == 1
draft_worker_kwargs[
"device_type"] = scorer_worker.device_config.device.type
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size
if draft_model_config.hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_model_config.hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
f"{draft_model_config.hf_config.model_type} "
"does not support TP > 1 yet")
allow_zero_draft_token_step = False
# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True
if envs.VLLM_ZERO_OVERHEAD:
assert False, (
"speculative decoding not support zero overhead scheduler yet"
)
from vllm.zero_overhead.spec_decode.muti_step_worker import ZeroOverheadMultiStepWorker
proposer_worker = ZeroOverheadMultiStepWorker(**draft_worker_kwargs)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = \
draft_model_config.hf_config.n_predict
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler()
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name(
) != "FLASH_ATTN":
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if draft_model_config and \
draft_model_config.max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len.")
if not scorer_worker.model_runner.model_config.enforce_eager:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker
return ZeroOverheadSpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
else:
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
def __init__(
self,
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_mqa_scorer: bool = False,
disable_logprobs: bool = False,
disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
enable_lm_head_weight_load: Optional[bool] = False,
num_spec_prefill_steps: int = 1,
):
"""
Create a SpecDecodeWorker.
Args:
proposer_worker: A worker that can produce speculative tokens for
sequences.
scorer_worker: A worker that produces probabilities of speculative
tokens according to some base model. Typically a vanilla vLLM
Worker.
spec_decode_sampler: A Torch module used to perform acceptance
sampling of the draft tokens in the verification step of
speculative decoding. Currently we support two different
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
enable_lm_head_weight_load: whether to load lm_head weight for
draft models like eagle.
num_spec_prefill_steps: number of speculative prefill steps to run
before the speculative decoding starts. This is only used when
the draft model is a deepseek_mtp model that requires prefill
kv cache separately for each MTP layer.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
scorer_runner = getattr(self.scorer_worker, "model_runner", None)
self.generators = scorer_runner.get_generators(
) if scorer_runner else None
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._enable_lm_head_weight_load = enable_lm_head_weight_load
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
# Tracks the sequence IDs that received a bonus token ID in
# their last forward pass. Needed only if KV cache is being
# used for token generation such as in the case of MultiStepWorker.
self._seq_with_bonus_token_in_last_step: Set[int] = set()
# Tracks the currently active request ids and the sequence IDs
# corresponding to them
self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
# Tracks if the proposer worker uses the KV cache or not.
self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initialization.
self.scorer: BatchExpansionTop1Scorer
self.disable_mqa_scorer = disable_mqa_scorer
# Hidden states from target model to pass to proposer
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self.previous_logits: Optional[Logits] = None
self.kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
self._num_spec_prefill_steps = num_spec_prefill_steps
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
"""
# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self.scorer_worker.init_device()
self.proposer_worker.init_device()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self.scorer_worker.load_model()
self.proposer_worker.load_model()
if self._enable_lm_head_weight_load:
# NOTE(Shangming): gather lm_head weight when tp enabled
target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
self.scorer_worker.model_runner.model_runner.model.lm_head.\
weight.data,
dim=0,
)
self.proposer_worker.maybe_load_lm_head_weight(
target_lm_head_weight)
self._metrics.init_tensors(self.rank, device_type=self.device)
if model_parallel_is_initialized():
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
else:
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
scorer_cls = BatchExpansionTop1Scorer
logger.info("[Speculative Decoding] Use batch "
"expansion for scoring proposals.")
else:
scorer_cls = MQAScorer
logger.info(
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
if not self.tree_decoding:
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
else:
self.scorer = BatchExpansionTreeStyleScorer(
scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode()
def load_model(self, *args, **kwargs):
pass
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of sampling during verification.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(self.scorer_worker.model_runner.sampler.include_gpu_probs_tensor
) = True
# tree_style decoding modify probs in _verify_tokens
if not self.tree_decoding:
(self.scorer_worker.model_runner.sampler.
should_modify_greedy_probs_inplace) = True
self.proposer_worker.set_include_gpu_probs_tensor()
self.proposer_worker.set_should_modify_greedy_probs_inplace()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the
larger of the two). Then the total memory which would be used by the
scorer cache is divided evenly between the proposer and scorer model KV,
such that the number of blocks is equal in both KV caches.
"""
num_gpu_blocks, num_cpu_blocks = (
self.scorer_worker.determine_num_available_blocks())
scorer_cache_block_size_bytes = (
self.scorer_worker.get_cache_block_size_bytes())
proposer_cache_block_size_bytes = (
self.proposer_worker.get_cache_block_size_bytes())
new_num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
num_gpu_blocks)
return new_num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the cache engine of the scorer and proposer workers.
"""
self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def get_model(self) -> nn.Module:
return self.scorer_worker.get_model()
@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
if self.rank != self._driver_rank:
self._run_non_driver_rank()
return []
if execute_model_req is None:
# This signals that there's no more requests to process for now.
# All workers are running infinite loop with broadcast_tensor_dict,
# and it stops the loop when the driver broadcasts an empty input.
# Send an empty input to notify all other workers to stop their
# execution loop.
broadcast_tensor_dict({}, src=0)
return []
self._track_finished_requests(execute_model_req)
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
num_lookahead_slots = execute_model_req.num_lookahead_slots
all_prompt = True
atleast_one_prompt = False
all_zero_spec_tokens = True
for sgm in execute_model_req.seq_group_metadata_list:
all_prompt = all_prompt and sgm.is_prompt
atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
all_zero_spec_tokens = all_zero_spec_tokens and (
sgm.num_speculative_tokens == 0)
if all_prompt and execute_model_req.seq_group_metadata_list:
assert num_lookahead_slots == 0, (
"Prompt only runs should have num_lookahead_slots equal to 0. "
"This should never happen, please file a bug at "
"https://github.com/vllm-project/vllm/issues")
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
# We expect `num_speculative_tokens` to be None for prefills.
no_spec = (num_lookahead_slots == 0 or disable_all_speculation
or all_zero_spec_tokens)
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform them.
# no_spec is used to signal non-driver worker about prefill vs decode
# stage. This is needed to ensure that order of execution of proposer
# and scorer is same in both driver and non-driver workers (i.e.,
# scorer -> proposer for prefill and proposer -> scorer in decode). This
# order is needed to support models like EAGLE that take scorer states
# as inputs.
broadcast_dict = dict(
num_lookahead_slots=num_lookahead_slots,
no_spec=no_spec,
disable_all_speculation=disable_all_speculation,
# When both chunked prefill and speculative decoding are enabled
# it is possible that the same batch contains both prefill
# and decodes. If that happens in the scorer we run the batch
# as one single forward pass. However, in the proposer we
# run them as 2 different batches - one for prefill and
# the other for decodes. The variable indicates to the non-driver
# worker that there are prefills as part of the speculative batch
# and hence it needs to run an extra prefill forward pass.
run_spec_proposer_for_prefill=atleast_one_prompt,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding requires non-None seq_group_metadata_list")
self._maybe_disable_speculative_tokens(
disable_all_speculation, execute_model_req.seq_group_metadata_list)
if no_spec:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop to perform speculative decoding
in parallel worker."""
while self._run_non_driver_rank():
pass
def _should_disable_all_speculation(
self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
return (execute_model_req.running_queue_size
>= self.disable_by_batch_size)
def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
if not disable_all_speculation:
return
for seq_group_metadata in seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
def _serialize_sampler_output_no_logprobs(
self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> List[SamplerOutput]:
"""
Creates and returns a `SamplerOutput` with only the token IDs being
serialized to CPU and populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
Args:
execute_model_req (ExecuteModelRequest): The model request that
was executed.
sampler_output (SamplerOutput): The output from the sampler with
only GPU tensors populated.
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only token IDs
populated.
"""
seq_output_prompt_logprobs = [
seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
and seq.sampling_params.prompt_logprobs > 0
for seq in execute_model_req.seq_group_metadata_list
]
# ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where(
# subtracting is faster than testing for equality
sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \
if any(seq_output_prompt_logprobs) else \
sampler_output.sampled_token_ids).tolist()
seq_data_entries = [
(seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
]
completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = []
output_index = 0
# Make sure the non-terminal prefill chunks are still aligned with
# their own empty output.
for idx, seq_group_meta in enumerate(
execute_model_req.seq_group_metadata_list):
needs_prompt_logprobs = seq_output_prompt_logprobs[idx]
seq_id, seq_data = seq_data_entries[idx]
if needs_prompt_logprobs:
prompt_token_ids = seq_data.get_prompt_token_ids()
# Some of these sequences may belong to non-terminal chunks,
# which may still have to report logprobs for prompts.
start = 1 if seq_data._num_computed_tokens == 0 \
else seq_data._num_computed_tokens
end = (seq_data._num_computed_tokens + \
seq_group_meta.token_chunk_size)
prompt_token_ids = prompt_token_ids[start:end]
prompt_logprobs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
) for p_token_id in prompt_token_ids
]
else:
prompt_logprobs = None
# Since we can get chunks here, we dont always have a sampled token
# (only on last chunk) but we still have to provide an output.
if not seq_group_meta.do_sample:
completion_seq_group_output_list.append(
CompletionSequenceGroupOutput(
samples=[], prompt_logprobs=prompt_logprobs))
continue
# Sequence with output.
completion_seq_group_output_list.append(
create_sequence_group_output(
token_id=sampled_token_ids_list[output_index][0],
token_id_logprob_rank=-1,
token_id_logprob=0.0,
seq_id=seq_id,
topk_token_ids=[],
topk_logprobs=[],
prompt_logprobs=prompt_logprobs))
output_index += 1
return [SamplerOutput(outputs=completion_seq_group_output_list)]
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
# Store hidden states from target model execution, BxD.
hidden_states = sampler_output.hidden_states
if hidden_states is not None:
# Only decodes and prefill terminal chunks need a hidden state.
seq_group_meta_with_hidden = [
sg for sg in execute_model_req.seq_group_metadata_list
if sg.do_sample
]
if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
# Drop hidden_states with no prediction (eg non-terminal chunks)
hidden_states = hidden_states[
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
# if not skip_proposer:
# if self.previous_hidden_states is None and len(
# seq_group_meta_with_hidden):
# self.previous_hidden_states = HiddenStates(
# hidden_states, seq_group_meta_with_hidden)
# elif self.previous_hidden_states and len(
# seq_group_meta_with_hidden):
# self.previous_hidden_states.update(hidden_states,
# seq_group_meta_with_hidden)
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_meta_with_hidden)
elif self.previous_hidden_states and len(
seq_group_meta_with_hidden):
self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
#self.previous_hidden_states.prune(seq_group_meta_with_hidden)
# Store logits from target model execution.
if self.tree_decoding:
logits = sampler_output.logits
if logits is not None:
if self.previous_logits is None:
self.previous_logits = Logits(
logits, execute_model_req.seq_group_metadata_list)
else:
self.previous_logits.update(
logits, execute_model_req.seq_group_metadata_list)
if not skip_proposer:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(
sampler_output.prefill_hidden_states)
for i in range(self._num_spec_prefill_steps):
execute_model_req.spec_step_idx = i
self.proposer_worker.execute_model(execute_model_req)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
[sampler_output])
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.sampled_token_probs = None
sampler_output.sampled_token_ids = None
sampler_output.logprobs = None
return sampler_output_to_return
def _run_non_driver_rank(self) -> bool:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
Returns True if there are remaining sequences to process.
"""
assert self.rank != self._driver_rank
data = broadcast_tensor_dict(src=self._driver_rank)
if not data:
return False
num_lookahead_slots = data["num_lookahead_slots"]
# In case of prefill, scorer_worker has to be run before proposer so
# that the hidden states can be propagated to proposer when needed.
if data["no_spec"]:
self.scorer_worker.execute_model()
if not data["disable_all_speculation"]:
# if not self.tree_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
# # We run the proposer once per lookahead slot. In the future we
# # should delegate how many times it runs to the proposer.
# for _ in range(max(num_lookahead_slots, 1)):
# self.proposer_worker.execute_model()
# else:
# if not data["no_spec"]:
# self.proposer_worker.sampler_output(None, None, None)
if issubclass(type(self.proposer_worker), NonLLMProposerWorkerBase):
if not data["no_spec"]:
self.proposer_worker.sampler_output(None, num_lookahead_slots, None)
else:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we
# should delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model()
if not data["no_spec"]:
self.scorer_worker.execute_model()
if data["run_spec_proposer_for_prefill"]:
self.proposer_worker.execute_model()
return True
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
self, execute_model_req: ExecuteModelRequest,
num_lookahead_slots: int) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
sequence, then scores each speculative token using the scoring worker.
When `enable_chunked_prefill` is set, scorer will batch decodes and
prefills, while proposer will sync its KV-cache by running an extra
forward on prefills.
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
# With prefill chunking, expect requests to have prompts first
# so that backend gets prefill|decode.
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
# Pass last hidden states from target model to proposer
execute_model_req.previous_hidden_states = self.previous_hidden_states
self.previous_hidden_states = None
# Pass last logits from target model to proposer
execute_model_req.previous_logits = self.previous_logits
self.previous_logits = None
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
with Timer() as proposal_timer:
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")
# Pass tree attention mask and postions to target model
if self.tree_decoding:
execute_model_req.tree_attn_masks = proposals.tree_attn_masks
execute_model_req.tree_position_ids = proposals.tree_position_ids
execute_model_req.previous_hidden_states = None
with Timer() as scoring_timer:
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)
_, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
# With prefill chunking enabled, `non_spec_seqs` contains prefills too:
# discard decodes that have already been processed by proposer.
non_spec_indices = [
idx for idx in non_spec_indices
if execute_model_req.seq_group_metadata_list[idx].is_prompt
]
if len(non_spec_indices):
all_hidden_states = proposal_scores.hidden_states
if all_hidden_states is not None:
prefill_hidden_states = all_hidden_states[non_spec_indices]
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(prefill_hidden_states)
# Sync proposer KV cache for prefills.
prefill_req = execute_model_req.clone(non_spec_seqs)
# TODO avoid sampling here?
self.proposer_worker.execute_model(prefill_req)
with Timer() as verification_timer:
accepted_token_ids, target_logprobs, select_indices_list, accept_lengths = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
# move kv_caches of selected tokens to right positions
if self.tree_decoding:
self.move_caches(execute_model_req, select_indices_list, accept_lengths)
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
scoring_timer.elapsed_time_ms,
verification_timer.elapsed_time_ms)
return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
prompt_logprobs=proposal_scores.prompt_logprobs
if not self._disable_logprobs else None,
k=execute_model_req.num_lookahead_slots,
stage_times=stage_times)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list = proposals.proposal_lens.tolist()
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
(_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, including bonus tokens.
if non_spec_indices:
proposal_verifier_probs = proposal_scores.probs[spec_indices]
else:
proposal_verifier_probs = proposal_scores.probs
if self.tree_decoding:
retrieve_indices = proposals.retrieve_indices
proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[:, -1:]
if non_spec_indices:
bonus_token_ids = bonus_token_ids[spec_indices, :]
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if proposal_probs is not None and non_spec_indices:
proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids
if non_spec_indices:
proposal_token_ids = proposal_token_ids[spec_indices]
# Get tree buffers.
cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if cart_candidates is not None and non_spec_indices:
cart_candidates = cart_candidates[spec_indices]
# Sampler arguments
sampler_extra_kwargs: Dict[str, Any] = {}
if self.generators and isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
sampler_extra_kwargs["seeded_seqs"] = {
idx: self.generators[sgm.request_id]
for idx, sgm in enumerate(seq_group_metadata_list)
if sgm.sampling_params.seed is not None
}
if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
sampler_extra_kwargs["cart_candidates"] = cart_candidates
sampler_extra_kwargs["best_candidates"] = []
sampler_extra_kwargs["accept_lengths"] = []
first_step_flags = []
for i, sgm in enumerate(seq_group_metadata_list):
seq = next(iter(sgm.seq_data.values()))
first_step_flags.append(True if seq.get_first_step_flag() else False)
sampler_extra_kwargs["first_step_flags"] = first_step_flags
accepted_token_ids = self.spec_decode_sampler(
target_with_bonus_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
if not self.tree_decoding:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
1).clone()
else:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len).clone()
non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()
# B x K+1 x D
hidden_states = proposal_scores.hidden_states
select_indices = None
accept_lengths = None
select_indices_list = []
if cart_candidates is None:
if hidden_states is not None:
# Only get terminal hidden states for next step
terminal_metadata = [
sg for sg in seq_group_metadata_list if sg.do_sample
]
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1]
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b
# Drop non-terminal prefill chunks hidden states.
hidden_states = hidden_states[accepted_index !=
VLLM_INVALID_TOKEN_ID]
accepted_index = accepted_index[accepted_index !=
VLLM_INVALID_TOKEN_ID]
assert len(accepted_index) == hidden_states.shape[0] == len(
terminal_metadata)
index = accepted_index[:, None, None].expand(-1, 1,
hs_size) # b x 1 x d
second_last_token_hidden_states = hidden_states[:, -2] # b x d
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
# Store hidden states from target model for subsequent decode step
self.previous_hidden_states = HiddenStates(
hidden_states, terminal_metadata,
second_last_token_hidden_states)
else:
retrieve_indices = proposals.retrieve_indices
batch_size = len(seq_group_metadata_list)
best_candidates = sampler_extra_kwargs["best_candidates"]
accept_lengths = sampler_extra_kwargs["accept_lengths"]
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1]
hidden_states = hidden_states.view(batch_size, -1, hs_size)
# Store logits from target model for subsequent proposal
logits = proposal_scores.logits
logits = logits.view(batch_size, -1, logits.shape[-1])
logits = logits[:, retrieve_indices] # [batch_size, retrieve_size, max_depth, vocab_size]
previous_logits_list = []
previous_hidden_state_list = []
retrieve_indices = retrieve_indices.cpu()
for i in range(batch_size):
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
previous_logits_list.append(logit)
select_indices = retrieve_indices[best_candidates[i], :accept_lengths[i]+1]
hidden_state = hidden_states[i, select_indices[-1]].unsqueeze(0)
select_indices_list.append(select_indices)
previous_hidden_state_list.append(hidden_state)
logits = torch.cat(previous_logits_list, dim=0)
self.previous_logits = Logits(logits, seq_group_metadata_list)
hidden_states = torch.cat(previous_hidden_state_list, dim=0) # [batch_size, 1, vocab_size]
self.previous_hidden_states = HiddenStates(hidden_states,
seq_group_metadata_list,)
return accepted_token_ids, logprobs, select_indices_list, accept_lengths
def move_caches(self, execute_model_req: ExecuteModelRequest,
select_indices_list: List[torch.Tensor],
accept_lengths: List[int]):
"""Given selected output tokens and accept length,
move kv_caches of selected tokens to right positions.
"""
seq_lens = []
for sg in execute_model_req.seq_group_metadata_list:
seq_ids = list(sg.seq_data.keys())
for seq_id in seq_ids:
seq_data = sg.seq_data[seq_id]
seq_len = seq_data.get_len()
token_chunk_size = sg.token_chunk_size
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)
# first step of tree-style decoding need to ignore first generated token
if seq_data.get_first_step_flag():
seq_len -= 1
# move cache is the last step of tree decoding, so set first_step_flag to false
seq_data.set_first_step_flag(False)
seq_lens.append(seq_len)
model_input = self.scorer._scorer_worker.model_input
block_tables = None
if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables_list'):
block_tables = model_input.attn_metadata.block_tables_list
if block_tables is None:
raise RuntimeError("Can not get block_tables from model_input.")
cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine]
block_size = cache_engine.block_size
batch_size = len(select_indices_list)
block_table_stride = len(block_tables) // batch_size
select_indices_slot_mapping = []
target_slot_mapping = []
for i in range(batch_size):
accept_legth = accept_lengths[i]
if accept_legth > 0:
select_indices = select_indices_list[i][1:] + seq_lens[i]
select_indices = select_indices.tolist()
self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride,
select_indices, block_size, block_tables)
target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i]
target_indices = target_indices.tolist()
self.compute_slot_mapping(target_slot_mapping, i*block_table_stride,
target_indices, block_size, block_tables)
if len(select_indices_slot_mapping) >0:
select_indices_slot_tensor = torch.tensor(select_indices_slot_mapping,
dtype=torch.long,
device=self.device).view(-1, 1)
target_slot_mapping_tensor = torch.tensor(target_slot_mapping,
dtype=torch.long,
device=self.device).view(-1, 1)
src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
self.kvcache_slot_to_be_moved = src_dst_tensor
def _create_output_sampler_list(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
prompt_logprobs: Optional[
torch.Tensor], # shape: [nprompt_tokens, vocab_size]
k: int,
stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
if self._disable_logprobs:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_dummy_logprob_lists(
batch_size, num_steps,
self.scorer_worker.model_config.max_logprobs)
else:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
# Serialize all tensors into Python lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_logprob_lists_from_tensors(
target_logprobs_by_step, accepted_token_ids_by_step,
self.scorer_worker.model_config.max_logprobs)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
seq_group_metadata_list)
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
# terminal chunks will only have one generated token at time 0.
sampler_output_list: List[SamplerOutput] = []
# Prefills are not multi-step (return at most 1 token), in order to
# avoid padding or repetition to fit decodes, we separate them.
for i, sg in enumerate(seq_group_metadata_list):
if not sg.is_prompt:
# Requests are ordered as prefills|decodes=>no more prefills.
break
num_logprobs = num_logprobs_per_seq[i]
seq_kwargs = dict(token_id=-1,
token_id_logprob_rank=0,
token_id_logprob=-float('inf'),
topk_token_ids=[-1] * num_logprobs,
topk_logprobs=[-float('inf')] * num_logprobs,
seq_id=seq_ids[i])
# Terminal chunk, has token.
if sg.do_sample:
seq_kwargs.update(
dict(
token_id=accepted_token_ids[i][0].item(),
token_id_logprob_rank=accepted_token_id_ranks_by_step[
0][i],
token_id_logprob=accepted_token_id_logprobs_by_step[0]
[i],
topk_token_ids=topk_indices_by_step[0][i]
[:num_logprobs],
# output only so step is 0
topk_logprobs=topk_logprobs_by_step[0][i]
[:num_logprobs],
))
needs_plogs = (sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
plogs = None
if prompt_logprobs is not None:
# Even non-terminal prompt chunks can have logprobs here.
plogs = prompt_logprobs[i]
elif needs_plogs:
# Prompt logprobs are requested but `_disable_logprobs` is set.
seq_data = next(iter(sg.seq_data.values()))
# Get only the tokens in this chunk!
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_token_ids = prompt_token_ids[
seq_data.
_num_computed_tokens:seq_data._num_computed_tokens +
sg.token_chunk_size]
is_first_chunk = seq_data._num_computed_tokens == 0
# There's no prob generated for the first token in a sequence.
if is_first_chunk:
prompt_token_ids = prompt_token_ids[1:]
plogs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
) for p_token_id in prompt_token_ids
]
seq_kwargs.update(dict(prompt_logprobs=plogs))
sampler_output_list.append(
SamplerOutput(
outputs=[create_sequence_group_output(
**seq_kwargs)])) # type: ignore
# Decodes, create one SamplerOutput per-step (at most K+1).
for step_index in range(num_steps):
if all(token_id == -1 for sg, token_id in zip(
seq_group_metadata_list,
accepted_token_ids_by_step[step_index])
if not sg.is_prompt):
break
step_output_token_ids: List[CompletionSequenceGroupOutput] = []
for sequence_index in range(batch_size):
seq_meta = seq_group_metadata_list[sequence_index]
# Prompts already processed above.
if seq_meta.is_prompt:
continue
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
step_index=step_index))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
# Populate the data structures needed to keep track of sequences with
# bonus tokens.
self._track_sequences_with_bonus_tokens(seq_ids,
request_ids_seq_ids_mapping,
accepted_token_ids_by_step)
maybe_rejsample_metrics = (
self._metrics.maybe_collect_rejsample_metrics(k))
if maybe_rejsample_metrics is not None and sampler_output_list:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self._maybe_log_stage_times(*stage_times)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
scoring_time_ms: float,
verification_time_ms: float) -> None:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if self._disable_log_stats:
return
logger.info(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f",
average_time_per_proposal_tok_ms, scoring_time_ms,
verification_time_ms)
def _create_dummy_logprob_lists(
self,
batch_size: int,
num_steps: int,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step = [[-1] * batch_size
for _ in range(num_steps)]
accepted_token_id_logprobs_by_step = [[0.0] * batch_size
for _ in range(num_steps)]
topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
topk_indices_by_step: List[List[List[Optional[int]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _create_logprob_lists_from_tensors(
self,
target_logprobs_by_step: torch.Tensor,
accepted_token_ids_by_step: torch.Tensor,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four lists representing token probabilities and
their ranks.
This method initializes and returns four lists containing:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
target_logprobs_by_step (torch.Tensor): Tensor representing the
log probabilities of the target model,
shaped (num_steps, batch_size, vocab_size)
accepted_token_ids_by_step (torch.Tensor): Tensor representing
the accepted token_ids, shaped (num_steps, batch_size)
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing the lists as described above.
"""
# Serialize all tensors to CPU Python lists.
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step_tensor,
accepted_token_id_logprobs_by_step_tensor
) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the
# logprob of the accepted token).
(topk_logprobs_by_step_tensor,
topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
k=num_top_k,
dim=-1,
)
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step_tensor.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step_tensor.tolist())
topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
topk_indices_by_step = topk_indices_by_step_tensor.tolist()
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
"""
Removes the finished requests and their associated sequence ids from
internal book keeping data structures.
"""
for finished_request in execute_model_req.finished_requests_ids:
for seq_id in self._request_id_seq_id_mapping[finished_request]:
self._seq_with_bonus_token_in_last_step.discard(seq_id)
del self._request_id_seq_id_mapping[finished_request]
def _track_sequences_with_bonus_tokens(
self, seq_ids: List[int],
request_ids_seq_ids_mapping: Dict[str, Set[int]],
accepted_token_ids_by_step: List[List[int]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for seq_index, seq_id in enumerate(seq_ids):
last_token_id = accepted_token_ids_by_step[-1][seq_index]
if last_token_id == -1:
self._seq_with_bonus_token_in_last_step.discard(seq_id)
else:
self._seq_with_bonus_token_in_last_step.add(seq_id)
for request_id, sequences in request_ids_seq_ids_mapping.items():
self._request_id_seq_id_mapping[request_id].update(sequences)
@cached_property
def _vocab_size(self) -> int:
"""Get the vocab size of the model and make sure it's consistent between
draft and target workers.
"""
vocab_sizes = [
worker.vocab_size
for worker in [self.proposer_worker, self.scorer_worker]
]
assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
return vocab_sizes[0]
@property
def rank(self):
return self.scorer_worker.rank
@property
def device(self):
return self.scorer_worker.device
@property
def _driver_rank(self) -> int:
return 0
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes.
This function is only used to compose workers within a SpecDecodeWorker.
We leave composing a SpecDecodeWorker within a SpecDecodeWorker
undefined for now, although it could be implemented in the future.
See https://arxiv.org/abs/2308.04623.
"""
raise NotImplementedError
def start_profile(self):
if isinstance(self.scorer_worker, WorkerBase):
self.scorer_worker.start_profile()
def stop_profile(self):
if isinstance(self.scorer_worker, WorkerBase):
self.scorer_worker.stop_profile()
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
proposer_cache_block_size_bytes: int,
total_num_gpu_blocks: int) -> int:
"""Given total_num_gpu_blocks, the number of GPU blocks that could be
allocate to the target model, this function calculates how many blocks
should be given to the draft and target model.
Note that usually the block size, in bytes, of each model is different,
as it's a function of number of KV/layer, number of heads, and hidden
dimension size.
Since the target and draft models allocate the same number of blocks, we
simply calculate the number of blocks where if allocated by both models,
the total memory usage from KV cache is no larger than the number of
blocks allocatable by the target model alone.
"""
new_num_gpu_blocks = int(
total_num_gpu_blocks * scorer_cache_block_size_bytes /
(proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
return new_num_gpu_blocks
def prepare_prefill_hidden_states(
prefill_hidden_states: torch.Tensor) -> HiddenStates:
# For prefill step in proposer, we run the model for N-1 tokens
# because Nth token will be processed in the first decode step. For
# N-1 tokens, the input should be 0:N-1 hidden states which should
# be concatanated with 1:N token (since output of scorer has to be
# the input for proposer). Therefore, we shift the hidden states to
# align n-1th hidden state with nth token.
return HiddenStates(prefill_hidden_states.roll(
shifts=1, dims=0)) if prefill_hidden_states is not None else None
from typing import List, Optional, Set, Tuple, Any, Dict
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (ExecuteModelRequest,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.util import sampler_output_to_torch
class TreeStyleProposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
worker: ProposerWorkerBase,
device: str,
vocab_size: int,
tree_buffers: Dict[str, Any],
max_proposal_len: Optional[int] = None,
):
self._worker = worker
self._device = device
self.tree_buffers = tree_buffers
self.max_proposal_len = max_proposal_len
self._vocab_size = vocab_size
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
#proposal_len = execute_model_req.num_lookahead_slots
proposal_len = self.tree_buffers["tree_indices"].shape[0]
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is true, then maybe_sampler_output's
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
hidden_states = execute_model_req.previous_hidden_states
if hidden_states is not None:
hidden_states.prune(nonzero_proposal_len_seqs)
logits = execute_model_req.previous_logits
if logits is not None:
logits.prune(nonzero_proposal_len_seqs)
nonzero_execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=nonzero_proposal_len_seqs,
num_lookahead_slots=proposal_len,
previous_hidden_states=hidden_states,
previous_logits=logits,
)
maybe_sampler_output, transposed = self._worker.sampler_output(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len,
seq_ids_with_bonus_token_in_last_step=\
seq_ids_with_bonus_token_in_last_step,
)
(
proposal_lens,
maybe_sampler_output,
nonzero_proposal_len_indices,
) = self._remove_no_proposal_seqs(proposal_lens,
maybe_sampler_output,
nonzero_proposal_len_indices,
transposed)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
transposed = False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens, cart_candidates, tree_attn_masks = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
proposal_len=proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
sampler_transposed=transposed,
)
tree_position_ids_list = []
for seq_group_metadata in seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))
if seq_data.get_first_step_flag():
seq_len = seq_data.get_len() -1
else:
seq_len = seq_data.get_len()
tree_position_ids = self.tree_buffers['tree_position_ids'] + seq_len
tree_position_ids_list.append(tree_position_ids)
tree_position_ids = torch.stack(tree_position_ids_list, dim=0).reshape(-1, 1)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
no_proposals=maybe_sampler_output is None,
cart_candidates=cart_candidates,
retrieve_indices=self.tree_buffers['retrieve_indices'],
tree_attn_masks=tree_attn_masks,
tree_position_ids=tree_position_ids)
return proposals
def _split_by_proposal_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if seq_group_metadata.num_speculative_tokens == 0:
proposal_lens.append(0)
continue
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exceed this
# quota for nonzero_proposal
new_k = 0
if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len):
new_k = proposal_len
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
proposal_lens.append(new_k)
seq_group_metadata.num_speculative_tokens = new_k
return (
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
)
@staticmethod
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices, transposed):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
(maybe_sampler_output=None). This can avoid scoring overheads.
"""
# If maybe_sampler_output is None, then the draft worker did not
# provide a proposal for any sequence and thus no action needed.
# Also we do not support transposed maybe_sampler_output for now
# because it seems not straightforward for draft workers outputting
# transposed sampler outputs to handle the case of no proposal.
if maybe_sampler_output is None or transposed:
return (proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices)
new_proposal_lens: List[int] = []
new_nonzero_proposal_len_indices: List[int] = []
new_maybe_sampler_output: List[SamplerOutput] = []
nonzero_proposal_len_idx_ptr = 0
seq_idx = 0
while seq_idx < len(
proposal_lens) and nonzero_proposal_len_idx_ptr < len(
nonzero_proposal_len_indices):
if seq_idx < nonzero_proposal_len_indices[
nonzero_proposal_len_idx_ptr]:
# Sequence is not in the original nonzero_proposal_len_indices,
# meaning that it has a proposal length of 0 before sending to
# the draft worker.
assert proposal_lens[seq_idx] == 0
new_proposal_lens.append(0)
else:
# Sequence is in the original nonzero_proposal_len_indices
if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
# but does not have a proposal from the draft worker.
new_proposal_lens.append(0)
else:
# and has a proposal from the draft worker. Add it to the
# new nonzero proposal list and keep the sampler output.
new_proposal_lens.append(proposal_lens[seq_idx])
new_nonzero_proposal_len_indices.append(seq_idx)
new_maybe_sampler_output.append(
maybe_sampler_output[nonzero_proposal_len_idx_ptr])
nonzero_proposal_len_idx_ptr += 1
seq_idx += 1
# The remaining sequences should have proposal length of 0.
new_proposal_lens.extend(proposal_lens[seq_idx:])
# We assume sampler_output will not be a list of all Nones.
# In this case this function should not be called.
assert new_maybe_sampler_output
return (new_proposal_lens, new_maybe_sampler_output,
new_nonzero_proposal_len_indices)
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
retrieve_indices = self.tree_buffers["retrieve_indices"]
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.tensor(-1,
dtype=torch.long,
device=self._device).expand(
batch_size, proposal_len)
proposal_probs = torch.tensor(0,
dtype=torch.float32,
device=self._device).expand(
batch_size, proposal_len,
self._vocab_size)
proposal_lens_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
len(proposal_lens))
cart_candidates_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
batch_size, retrieve_indices.shape[0],
retrieve_indices.shape[1])
tree_attn_masks_tensor = torch.tensor(0,
dtype=torch.int32,
device=self._device).expand(
batch_size, self.tree_buffers["tree_attn_masks"].shape[0],
self.tree_buffers["tree_attn_masks"].shape[1])
return proposal_tokens, proposal_probs, proposal_lens_tensor, cart_candidates_tensor, tree_attn_masks_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _, _, cart_candidates, tree_attn_masks = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = None
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
entire_cart_candidates = cart_candidates.new_zeros(
batch_size,
*cart_candidates.shape[1:],
)
entire_cart_candidates[nonzero_proposal_len_indices] = cart_candidates
entire_tree_attn_masks = tree_attn_masks.new_zeros(
batch_size,
*tree_attn_masks.shape[1:],
)
entire_tree_attn_masks[nonzero_proposal_len_indices] = tree_attn_masks
entire_tree_attn_masks = entire_tree_attn_masks.reshape(-1, tree_attn_masks.shape[-1])
return proposal_tokens, proposal_probs, proposal_lens_tensor, entire_cart_candidates, entire_tree_attn_masks
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Sequence, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)
SeqId = int
def get_all_num_logprobs(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)
return all_num_logprobs
def get_sampled_token_logprobs(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps, batch_size, vocab_size = logprob_tensor.shape
selected_logprobs = logprob_tensor[
torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids,
]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor
> expanded_selected_logprobs).sum(-1).add_(1)
return sampled_token_ids_ranks, selected_logprobs
def create_logprobs_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
) -> Dict[int, Logprob]:
"""Create a Logprob Dict for a token given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_id: Logprob(
logprob=topk_logprob if topk_logprob is not None else 0.0,
rank=topk_index + 1,
)
for topk_index, (topk_token_id, topk_logprob) \
in enumerate(zip(topk_token_ids, topk_logprobs)) \
if topk_token_id is not None
})
return logprobs
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
step_index: (Optional[int]): The index of the speculative token.
"""
logprobs = create_logprobs_output(
token_id,
token_id_logprob_rank,
token_id_logprob,
topk_token_ids,
topk_logprobs,
)
return CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
prompt_logprobs=prompt_logprobs,
step_index=step_index)
def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int],
) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
List[SequenceGroupMetadata], List[int]]]:
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
for i, (seq_group, proposal_len) in enumerate(
zip(seq_group_metadata_list, proposal_lens)):
seq_groups, indices = nonzero_lists if proposal_len else zero_lists
seq_groups.append(seq_group)
indices.append(i)
return nonzero_lists, zero_lists
def sampler_output_to_torch(
sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
sampled_token_probs: torch.Tensor
shape: [batch_size, len(sampler_output_list), vocab_size]
"""
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_probs = None
if sampler_output_list[0].sampled_token_probs is not None:
sampled_token_probs = torch.stack(
[
sampler_output.sampled_token_probs
for sampler_output in sampler_output_list
],
dim=0,
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = None
if sampler_output_list[0].logprobs is not None:
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
[
sampler_output.sampled_token_ids.flatten()
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
sampled_token_ids = sampled_token_ids.transpose(0, 1)
if sampler_output_list[0].hidden_states is not None:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states = torch.stack(
[
sampler_output.hidden_states
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
else:
sampled_hidden_states = None
sampled_cart_candidates = None
if sampler_output_list[0].cart_candidates is not None:
sampled_cart_candidates = torch.cat(
[
sampler_output.cart_candidates
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_cart_candidates = sampled_cart_candidates.transpose(0, 1)
sampled_tree_attn_masks = None
if sampler_output_list[0].tree_attn_masks is not None:
sampled_tree_attn_masks = torch.stack(
[
sampler_output.tree_attn_masks
for sampler_output in sampler_output_list
],
dim=0,
)
if sampler_transposed:
sampled_tree_attn_masks = sampled_tree_attn_masks.transpose(0, 1)
return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
sampled_hidden_states, sampled_cart_candidates, sampled_tree_attn_masks)
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
vocab_size: int, device: str) -> None:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values = [
sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
]
assert all(v is None for v in values) or not any(v is None for v in values)
if not any(v is None for v in values):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output.sampled_token_probs = torch.nn.functional.softmax(
torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
dim=-1)
sampler_output.sampled_token_ids = torch.randint(low=10,
high=100,
size=(batch_size, ),
dtype=torch.long,
device=device)
@contextmanager
def nvtx_range(msg, *args, **kwargs):
"""
Context manager / decorator that pushes an NVTX range at the beginning
of its scope, and pops it at the end. If extra arguments are given,
they are passed as arguments to msg.format().
If running with cuda graphs, you must enable nsys cuda graph profiling.
Arguments:
msg (string): message to associate with the range
"""
if current_platform.is_cuda_alike():
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
try:
yield
finally:
torch.cuda.nvtx.range_pop()
else:
yield
class Timer:
"""Basic timer context manager for measuring CPU time.
"""
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time_s = self.end_time - self.start_time
self.elapsed_time_ms = self.elapsed_time_s * 1000
# SPDX-License-Identifier: Apache-2.0
import os
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
default_dump_dir, default_override_dir)
from vllm.logger import init_logger
logger = init_logger(__name__)
def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
logger.info("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager
class CustomCacheManager(FileCacheManager):
"""Re-implements Triton's cache manager, ensuring that a
unique cache directory is created for each process. This is
needed to avoid collisions when running with tp>1 and
using multi-processing as the distributed backend.
Note this issue was fixed by triton-lang/triton/pull/4295,
but the fix is not yet included in triton==v3.0.0. However,
it should be included in the subsequent version.
"""
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv("TRITON_CACHE_DIR",
"").strip() or default_cache_dir()
if self.cache_dir:
# self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")
...@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention( ...@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
for chunk_idx in range(cdiv(C, MCC)): for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C) chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start_table Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
......
...@@ -45,10 +45,8 @@ class CommonAttentionMetadata: ...@@ -45,10 +45,8 @@ class CommonAttentionMetadata:
seq_lens_cpu: torch.Tensor seq_lens_cpu: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens """(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens""" and newly scheduled tokens"""
num_computed_tokens_cpu: torch.Tensor num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request""" """(batch_size,), the number of computed tokens for each request"""
num_reqs: int num_reqs: int
"""Number of requests""" """Number of requests"""
num_actual_tokens: int num_actual_tokens: int
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInputForCPU":
return cast(
EncoderDecoderModelInputForCPU,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class CPUEncoderDecoderModelRunner(
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str,
Any]) -> EncoderDecoderModelInputForCPU:
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU:
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
return dataclasses.replace(
model_input,
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
virtual_engine=virtual_engine,
)
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInputForCPU,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
# Since we are not supporting chunked prefill either the entire
# batch is prefill or it is decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
max_len_of_block_table = max(
len(block_table) for block_table in cross_block_tables)
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max_len_of_block_table,
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)
@torch.no_grad()
def execute_model(
self,
model_input: EncoderDecoderModelInputForCPU,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"encoder_input_ids":
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
"intermediate_tensors":
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
TypeVar, Union)
import torch
from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
_PAD_SLOT_ID = -1
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
"""
Base class contains metadata needed for the base model forward pass on CPU
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_type_ids: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForCPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> TModelInputForCPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForCPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
class ModelInputData:
def __init__(self, use_mrope: bool):
self.use_mrope = use_mrope
self.input_tokens: List[int] = []
self.input_positions: List[int] = []
self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.prefill_block_tables: List[List[int]] = []
self.decode_block_tables: List[List[int]] = []
self.max_decode_seq_len: int = 0
self.num_prefills: int = 0
self.num_prefill_tokens: int = 0
self.num_decode_tokens: int = 0
self.slot_mapping: List[int] = []
self.multi_modal_inputs_list: List[MultiModalKwargs] = []
self.multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap)
self.input_mrope_positions: List[List[int]] = [[]
for _ in range(3)]
def __init__(self,
runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.runner = runner
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.enable_lora = self.runner.lora_config is not None
if self.runner.attn_backend is not None:
# spec decode (e.g. Medusa) does not have atten backend
attn_backend = self.runner.attn_backend
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope)
self.att_metadata_builder.prepare()
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def set_seq_group_list(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
self.seq_group_metadata_list = seq_group_metadata_list
def build(self) -> ModelInputForCPU:
self._build_input_data()
input_data = self.input_data
input_tokens = torch.tensor(input_data.input_tokens,
dtype=torch.long,
device="cpu")
input_positions = torch.tensor(
input_data.input_positions
if not any(input_data.input_mrope_positions) else
input_data.input_mrope_positions,
dtype=torch.long,
device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids,
dtype=torch.long,
device="cpu") \
if input_data.token_type_ids else None
# For multi-modal models
multi_modal_kwargs = None
if len(input_data.multi_modal_inputs_list) != 0:
multi_modal_kwargs = MultiModalKwargs.batch(
input_data.multi_modal_inputs_list)
attn_metadata = self.att_metadata_builder.build(
input_data.seq_lens, input_data.query_lens, -1, -1)
is_prompt = (self.seq_group_metadata_list[0].is_prompt
if self.seq_group_metadata_list else None)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(seq.lora_request
for seq in self.seq_group_metadata_list
if seq.lora_request is not None)
lora_mapping = self._prepare_lora_input(
self.seq_group_metadata_list, is_prompt)
return self.model_input_cls(input_tokens=input_tokens,
input_positions=input_positions,
token_type_ids=token_type_ids,
seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
lora_mapping=lora_mapping,
lora_requests=lora_requests)
def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list:
for seq_id, seq_data in seq_group_metadata.seq_data.items():
if seq_group_metadata.is_prompt:
self._compute_prompt_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
if seq_group_metadata.multi_modal_data:
self._compute_multi_modal_input(
seq_group_metadata, seq_data)
else:
self._compute_decode_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
def _compute_decode_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute decode input tokens, positions, block table and slot mapping.
"""
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
tokens = seq_data.get_last_token_id()
token_positions = seq_len - 1
block_number = block_table[token_positions // block_size]
block_offset = token_positions % block_size
slot = block_number * block_size + block_offset
# For paged_attention kernel
if self.runner.sliding_window:
start_idx = max(0, seq_len - self.runner.sliding_window)
start_block = start_idx // block_size
start_idx = start_block * block_size
seq_len = seq_len - start_idx
block_table = block_table[start_block:]
# For MRotaryEmbedding
if seq_data.mrope_position_delta is not None:
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
data.input_mrope_positions[idx].extend( # type: ignore
next_pos[idx])
else:
data.input_positions.append(token_positions) # type: ignore
# Update fields
data.input_tokens.append(tokens)
data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
data.num_decode_tokens += 1
data.slot_mapping.append(slot)
data.decode_block_tables.append(block_table)
data.query_lens.append(1)
data.seq_lens.append(seq_len)
def _compute_prompt_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute prompt input tokens, positions, block table and slot mapping.
"""
token_chunk_size = seq_group_metadata.token_chunk_size
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
# For prefix caching
prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
if prefix_cache_block_num > 0:
prefix_cache_len = (prefix_cache_block_num *
self.runner.block_size)
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
context_len = prefix_cache_len
token_chunk_size = seq_len - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
context_len = seq_len - 1
token_chunk_size = 1
tokens = seq_data.get_token_ids()
tokens = tokens[context_len:seq_len]
token_positions = range(context_len, seq_len)
token_types = seq_group_metadata.token_type_ids
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if block_table is not None:
slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
for i, pos in enumerate(token_positions):
block_number = block_table[pos // block_size]
block_offset = pos % block_size
slot = block_number * block_size + block_offset
slot_mapping[i] = slot
data.slot_mapping.extend(slot_mapping)
# The MROPE positions are prepared in _compute_multi_modal_input
data.input_positions.extend(token_positions)
if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else [])
# Update fields
data.input_tokens.extend(tokens)
data.num_prefills += 1
data.num_prefill_tokens += len(tokens)
data.query_lens.append(len(tokens))
data.prefill_block_tables.append(block_table)
data.seq_lens.append(seq_len)
def _compute_multi_modal_input(self,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData):
computed_len = seq_data.get_num_computed_tokens()
seq_len = self.input_data.seq_lens[-1]
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata, range(computed_len, seq_len))
if not mm_kwargs:
return
# special processing for mrope position deltas.
if self.runner.model_config.uses_mrope:
assert not self.chunked_prefill, \
"MROPE on CPU does not support chunked-prefill."
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
None)
assert (
image_grid_thw is not None or video_grid_thw is not None
or audio_feature_lengths is not None), (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw' or "
"'audio_feature_lengths'.")
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=computed_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
seq_data.mrope_position_delta = mrope_position_delta
for i in range(3):
self.input_data.input_mrope_positions[ # type: ignore
i].extend(mrope_positions[i])
self.input_data.multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
self.input_data.multi_modal_placeholder_maps[modality].extend(
placeholder_map)
def _prepare_lora_input(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
is_prefill: bool) -> LoRAMapping:
index_mapping = []
prompt_mapping = []
for seq in seq_group_metadata_list:
lora_id = seq.lora_int_id
query_len = seq.token_chunk_size
index_mapping += [lora_id] * query_len
prompt_mapping += [lora_id] * (
query_len if seq.sampling_params
and seq.sampling_params.prompt_logprobs is not None else 1)
return LoRAMapping(index_mapping=tuple(index_mapping),
prompt_mapping=tuple(prompt_mapping),
is_prefill=is_prefill)
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
Helper class for shared methods between CPU model runners.
"""
_model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder]
builder: ModelInputForCPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
*args,
**kwargs,
):
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.pin_memory = False
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
# Lazy initialization.
self.model: nn.Module # Set after init_Model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.sampler = get_sampler()
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# Use get_text_config() in case of multimodal models
text_config = self.model_config.hf_config.get_text_config()
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=text_config.max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForCPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
self.builder.prepare(finished_requests_ids)
self.builder.set_seq_group_list(seq_group_metadata_list)
return self.builder.build() # type: ignore
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine,
is_prompt=is_prompt)
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
previous_hidden_states: Optional[torch.Tensor] = None,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
model_executable = self.model
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs,
device=self.device,
)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**execute_model_kwargs,
**multimodal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
return [output]
def generate_proposals(self, *args, **kwargs):
return self.model.generate_proposals(*args, **kwargs)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
ModelInputForCPUBuilder)
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUPoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class CPUPoolingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
cross_enc_kwargs = {}
if model_input.token_type_ids is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return []
return [
self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForCPUWithPoolingMetadata:
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
virtual_engine=virtual_engine,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import os
from importlib import util
from typing import List, Optional, Set, Tuple, Type
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.worker.cache_engine import CacheEngine
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class CPUCacheEngine:
"""Manages the KV cache for CPU backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
# Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
# for CPU backend, because we want to reuse KV cache management
# in the scheduler.
self.num_cpu_blocks = cache_config.num_gpu_blocks
self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
model_config)
# Get attention backend.
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
def _allocate_kv_cache(
self,
num_blocks: int,
) -> List[torch.Tensor]:
"""Allocates KV cache on CPU."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
kv_cache.append(
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def swap_out(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod
def get_kv_cache_dtype(cache_config: CacheConfig,
model_config: ModelConfig):
if cache_config.cache_dtype == "auto":
return model_config.dtype
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
return torch.float8_e5m2
else:
raise NotImplementedError(f"Unsupported KV cache type "
f"{cache_config.cache_dtype}.")
@staticmethod
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block if not model_config.use_mla else 0
total = num_layers * (key_cache_block + value_cache_block)
dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
dtype_size = torch.tensor([], dtype=dtype).element_size()
return dtype_size * total
class CPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[CPUModelRunner]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
vllm_config.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
self.local_omp_cpuid = "all"
if omp_cpuids == "auto":
self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
)
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.runner_type == "pooling":
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CPUCacheEngine]
# Initialize cpu_cache as pooling models don't initialize kv_caches
self.cpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
self.device = torch.device("cpu")
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured CPU
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For CPU device, the block number will be calculated based on the
# cpu_kvcache_space.
cache_block_size = self.get_cache_block_size_bytes()
num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
cache_block_size)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks = num_cpu_blocks
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert (num_cpu_blocks == 0
), f"{type(self)} does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks = num_gpu_blocks
self._validate_num_cpu_blocks(num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_cpu_blocks
self.cache_config.num_cpu_blocks = 0
# Initialize the cache.
self._init_cache_engine()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
"""Raise errors if the num_cpu_blocks is invalid.
"""
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.cpu_cache)
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory
for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
def execute_worker(
self,
worker_input: WorkerInput,
) -> None:
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
virtual_engine: int = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.
"""
return CPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
"""Return CPUs id binding based on NUMA nodes.
"""
rank_to_cpus = self.local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size = self.vllm_config.parallel_config.world_size
libnuma_found = util.find_spec("numa") is not None
psutil_found = util.find_spec("psutil") is not None
if libnuma_found and psutil_found:
import psutil
from numa import info
cpu_count = psutil.cpu_count(logical=False)
cpus_allow_list = psutil.Process().cpu_affinity()
numa_size = info.get_num_configured_nodes()
cpu_count_per_numa = cpu_count // numa_size
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
cpu_count_per_numa // 2)
# check allow node_to_cpus list
node_to_cpus = []
for i in range(numa_size):
node_intersect = set(
info.node_to_cpus(i)).intersection(cpus_allow_list)
if bool(node_intersect):
node_to_cpus.append(list(node_intersect))
if world_size > len(node_to_cpus):
logger.error(
"Auto thread-binding failed due to "
"world size: %d is larger than "
"allowed NUMA nodes number: %d."
"Please try to bind threads manually.", world_size,
len(node_to_cpus))
else:
end = cpu_count_per_numa - num_of_reserved_cpu
rank_to_cpus_list = node_to_cpus[self.rank][:end]
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
logger.info("auto thread-binding list: %s", rank_to_cpus)
else:
logger.warning(
"Auto thread-binding is not supported due to "
"the lack of package numa and psutil,"
"fallback to no thread-binding. To get better performance,"
"please try to manually bind threads.")
return rank_to_cpus
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepTPUWorker(TPUWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cached_model_input: Optional[ModelInputForTPU] = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
assert self.is_driver_worker
assert execute_model_req.virtual_engine == 0
is_first_multi_step = execute_model_req.is_first_multi_step
is_last_step = execute_model_req.is_last_step
if is_first_multi_step:
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
worker_input = dataclasses.replace(
worker_input,
num_steps=execute_model_req.num_lookahead_slots + 1)
model_input: ModelInputForTPU = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input = dataclasses.replace(
model_input,
async_callback=execute_model_req.async_callback)
else:
assert self.cached_model_input is not None
model_input = self.cached_model_input
worker_input = WorkerInput()
model_input = dataclasses.replace(
model_input,
is_first_multi_step=is_first_multi_step,
is_last_step=is_last_step)
if self.do_metadata_broadcast:
if is_first_multi_step:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
broadcast_data = {
"is_first_multi_step": is_first_multi_step,
"is_last_step": is_last_step,
}
broadcast_tensor_dict(broadcast_data, src=0)
# Retuning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
torch.Tensor]]]:
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
broadcast_tensor_dict({}, src=0)
return None
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
execute_model_req)
if model_input.is_first_multi_step:
self.cached_model_input = model_input
return model_input, worker_input, {}
else:
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
if len(broadcast_data) == 2:
assert self.cached_model_input is not None
self.cached_model_input = dataclasses.replace(
self.cached_model_input,
is_first_multi_step=broadcast_data["is_first_multi_step"],
is_last_step=broadcast_data["is_last_step"])
empty_worker_input = WorkerInput()
return self.cached_model_input, empty_worker_input, {}
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.cached_model_input = model_input
return model_input, worker_input, {}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
n: List[int]
seq_groups: List[List[int]]
is_first_multi_step: bool = True
is_last_step: bool = True
virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
"n": self.n,
"seq_groups": self.seq_groups,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
self.block_size)
self.block_tables = np.zeros(
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
dtype=np.int32)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
False,
)
self.cached_step_outputs: List[torch.Tensor] = []
smem_size = 512 * 1024
block_table_size = 4 * self.block_tables.size
if block_table_size >= smem_size:
logger.warning(
"The max_model_len (%d) is too large. This may degrade the "
"performance due to the insufficient smem size. Consider "
"setting --max-model-len to a smaller value, like %d.",
self.model_config.max_model_len,
self.model_config.max_model_len /
(block_table_size / smem_size))
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def get_model(self) -> nn.Module:
return self.model.model
def _dummy_run(
self,
batch_size: int,
seq_len: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
exec_mode: ExecutionMode,
) -> None:
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(batch_size, self.max_num_blocks_per_seq),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
# Prefll
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(input_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, input_lens, t, p, num_samples,
kv_caches)
def warmup_model(
self,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> None:
# Prefill
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefill done in %.2f s.", end - start)
# Prefix prefill
if self.cache_config.enable_prefix_caching:
logger.info("Compiling the model with different input shapes for "
"prefix prefill...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefix prefill done in %.2f s.",
end - start)
# Decode
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info("Compilation for decode done in %.2f s.", end - start)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
context_lens: List[int] = []
slot_mapping: List[int] = []
for batch_idx, seq_group_metadata in enumerate(
seq_group_metadata_list):
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
# Could include output tokens when a request is preempted.
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
num_computed_tokens = num_computed_blocks * self.block_size
if num_computed_tokens > 0:
prompt_tokens = prompt_tokens[num_computed_tokens:]
context_lens.append(seq_len)
else:
context_lens.append(0)
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
input_positions.extend(range(num_computed_tokens, seq_len))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(num_computed_tokens, seq_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if num_computed_tokens > 0:
self.block_tables[batch_idx, :len(block_table)] = block_table
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:num_prefills],
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
)
return input_tokens, input_positions, attn_metadata, prompt_lens
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device="cpu")
input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
return input_tokens, input_positions, attn_metadata, input_lens
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
n = []
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
t.append(sampling_params.temperature)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues.")
p.append(sampling_params.top_p)
if sampling_params.top_k > 0:
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.n > _MAX_NUM_SAMPLES:
raise NotImplementedError(
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
n.append(sampling_params.n)
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")
if sampling_params.prompt_logprobs is not None:
raise NotImplementedError(
"prompt_logprobs is not currently supported by the TPU "
"backend.")
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
n += [n[-1]] * (num_seqs - 1)
num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, n
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, n = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
seq_groups = [
list(metadata.seq_data.keys())
for metadata in seq_group_metadata_list
]
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, n, seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
return []
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=i == 0)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
return sampler_outputs
is_prompt = model_input.attn_metadata.num_prefills > 0
if is_prompt:
assert num_steps == 1
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
orig_slot_mapping = model_input.attn_metadata.slot_mapping
orig_block_tables = model_input.attn_metadata.block_tables
orig_context_lens = model_input.attn_metadata.context_lens
orig_effective_query_lens = \
model_input.attn_metadata.effective_query_lens
batch_size = model_input.input_lens.shape[0]
start_idx = 0
next_token_ids = []
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
self.device)
position_ids = model_input.position_ids[None,
start_idx:end_idx].to(
self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx].to(self.device)
if orig_context_lens[i].item() > 0:
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
self.device)
attn_metadata.block_tables = orig_block_tables[
i].unsqueeze(0).to(self.device)
attn_metadata.effective_query_lens = \
orig_effective_query_lens[i:i + 1].to(self.device)
else:
attn_metadata.context_lens = None
attn_metadata.block_tables = None
attn_metadata.effective_query_lens = None
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
input_lens, t, p,
model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids = [
output_token_ids.cpu().tolist()
for output_token_ids in next_token_ids
]
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support advanced sampling parameters such as logprobs.
zero_logprob = Logprob(0.0)
sampler_outputs = []
for i, seq_group in enumerate(model_input.seq_groups):
seq_ids = seq_group
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_outputs = []
for j in range(model_input.n[i]):
next_token_id = next_token_ids[i][j]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return [SamplerOutput(sampler_outputs)]
else:
token_ids = model_input.token_ids.to(self.device)
position_ids = model_input.position_ids.to(self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
self.device)
attn_metadata.block_tables = attn_metadata.block_tables.to(
self.device)
attn_metadata.context_lens = attn_metadata.context_lens.to(
self.device)
t = model_input.t.to(self.device)
p = model_input.p.to(self.device)
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
input_lens, t, p,
model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)
if i < num_steps - 1:
# Prepare the inputs for the next step.
token_ids = output_token_ids.unsqueeze(dim=1).int()
position_ids = position_ids + 1
attn_metadata.context_lens = attn_metadata.context_lens + 1
block_tables = attn_metadata.block_tables
block_number = block_tables.gather(
1,
position_ids.long() // self.block_size)
block_offset = position_ids % self.block_size
is_padding = slot_mapping == _PAD_SLOT_ID
slot_mapping = block_number * self.block_size + block_offset
slot_mapping = slot_mapping.long()
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
slot_mapping)
attn_metadata.slot_mapping = slot_mapping
if model_input.async_callback is not None:
model_input.async_callback()
if num_steps > 1:
return []
# Retrieve the outputs to CPU.
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
return [sampler_output]
class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from.
start_indices = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = start_indices + input_lens - 1
attn_metadata = get_forward_context().attn_metadata
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
sampling_metadata = SamplingMetadata(
seq_groups=[],
selected_token_indices=logits_indices,
categorized_sample_indices={},
num_prompts=attn_metadata.num_prefills,
)
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indices = torch.arange(0,
num_kv_heads,
device=slot_mapping.device,
dtype=slot_mapping.dtype)
head_indices *= block_size * num_blocks
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indices.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
hidden_states = self.model(token_ids, position_ids)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
if num_samples == 1:
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids
def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_batch_size(batch_size: int) -> int:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
return logits
def _make_decode_output(
next_token_ids: List[int],
seq_groups: List[List[int]],
) -> SamplerOutput:
zero_logprob = Logprob(0.0)
sampler_outputs = []
batch_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group
seq_outputs = []
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(CompletionSequenceGroupOutput(
seq_outputs, None))
return SamplerOutput(sampler_outputs)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
assert self.device_config.device_type == "tpu"
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.model_runner: TPUModelRunner = TPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V0 TPU backend doesn't support LoRA serving")
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=self.parallel_config.world_size,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
# Device initialization should happen after initializing the distributed
# runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended.We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
self.profiler = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)
self.profiler = xp.start_server(9012)
def start_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.start_trace(self.profile_dir)
def stop_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.stop_trace()
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
num_layers = self.model_config.get_num_layers(self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device),
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
bind_kv_cache(self.compilation_config.static_forward_context,
[kv_caches])
self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
kv_caches=kv_caches,
exec_mode=ExecutionMode.PREFILL,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_bytes = get_dtype_size(self.cache_dtype)
block_size_bytes = (dtype_bytes * self.cache_config.block_size *
num_layers * 2 * head_size * num_kv_heads)
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
# Calculate the CPU KV cache size based on the config.
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, num_cpu_blocks
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.block_size = self.cache_config.block_size
dtype = self.cache_dtype
num_layers = self.model_config.get_num_layers(self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros(cpu_cache_shape,
dtype=dtype,
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
bind_kv_cache(self.compilation_config.static_forward_context,
[self.tpu_cache])
self._warmup_model()
def _warmup_model(self) -> None:
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
# for CUDA graphs. We should refactor this part.
if not self.model_config.enforce_eager:
# Warm up the model with all possible input shapes so that
# compilation never happens during the actual execution.
# This may take ~30 mins for the first run and ~20 mins for the
# subsequent runs.
# If `enforce_eager` is True, the ahead-of-time compilation is
# skipped and the compilation happens during the actual execution,
# which is bad for performance but useful for development.
self.model_runner.warmup_model(self.tpu_cache)
def get_cache_block_size_bytes(self) -> int:
head_size = self.model_config.get_head_size()
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
num_layers = self.model_config.get_num_layers(self.parallel_config)
key_cache_block = self.cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
def prepare_worker_input(
self,
execute_model_req: ExecuteModelRequest,
) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
blocks_to_swap_in = _make_src_to_dst(
execute_model_req.blocks_to_swap_in, "cpu", self.device)
blocks_to_swap_out = _make_src_to_dst(
execute_model_req.blocks_to_swap_out, self.device, "cpu")
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
self.device, self.device)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
assert virtual_engine == 0
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)
# Issue cache operations.
if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_in
if src_indices.numel() > 0:
# Swap from CPU to TPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_out
if src_indices.numel() > 0:
# Swap from TPU to CPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if worker_input.blocks_to_copy is not None:
src_indices, dst_indices = worker_input.blocks_to_copy
if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
(src_indices, dst_indices))
def _make_src_to_dst(
mapping: List[Tuple[int, int]],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
if not mapping:
return None
src_indices = [i for i, _ in mapping]
dst_indices = [i for _, i in mapping]
src_indices = torch.tensor(src_indices,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_indices,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
@torch.compile(backend="openxla")
def _insert_kv(
k: torch.Tensor,
v: torch.Tensor,
indices: torch.Tensor,
tpu_k_cache: torch.Tensor,
tpu_v_cache: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
tpu_k_cache[:, indices] = k
tpu_v_cache[:, indices] = v
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import time
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)
import torch
import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForXPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForXPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForXPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
def __init__(self,
runner: "XPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def build(self) -> ModelInputForXPU:
is_prompt = self.seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) = self._prepare_prompt(
self.seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = None
multi_modal_kwargs = None
return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
seq_lens=seq_lens,
query_lens=seq_lens,
)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
seq_len = len(prompt_tokens)
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
if seq_group_metadata.multi_modal_data:
# NOTE: mm_kwargs only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device) # type: ignore
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
max_seqlen = max(seq_lens)
tmp = [0]
tmp.extend(seq_lens)
seqlen = torch.tensor(tmp)
seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=0,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
block_tables = make_tensor_with_pad(
block_tables,
pad=0,
dtype=torch.int,
device=self.device,
)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,
seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
num_prefills=0,
block_tables=block_tables,
)
return (
input_tokens,
input_positions,
attn_metadata,
)
class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
ModelInputForXPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
# Lazy initialization.
self.model: nn.Module # Set after init_Model
self.sampler = get_sampler()
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GiB",
self.model_memory_usage / GiB_bytes)
def get_model(self) -> nn.Module:
return self.model
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
if max_num_seqs < 1:
expr = (f"min({max_num_seqs_orig}, "
f"{max_num_batched_tokens} // {max_mm_tokens})")
logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=None,
multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq)
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, None, intermediate_tensors)
torch.xpu.synchronize()
return
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForXPUWithSamplingMetadata:
return (
ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPUWithSamplingMetadata:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder = self.builder
builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForXPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"XPUModelRunner does not support multi-step execution.")
model_executable = self.model
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time()
# Compute the logits.
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_time = (model_forward_end_time -
model_forward_start_time)
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = model_forward_time
return [output]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(LoRANotSupportedWorkerBase, Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is
responsible for maintaining the KV cache and executing the model on the
XPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
device_config = self.device_config
parallel_config = self.parallel_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
self.model_runner = XPUModelRunner( # type: ignore
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]
self.gpu_cache: Optional[List[List[torch.Tensor]]]
def init_device(self) -> None:
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
):
self.device = torch.device(f"xpu:{self.local_rank}")
torch.xpu.set_device(self.device)
torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
self.init_worker_distributed_environment()
# Initialize the model.
set_random_seed(self.model_config.seed)
# keep this method for `empty_cache` and `synchronize` api
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.xpu.synchronize()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
free_gpu_memory = total_gpu_memory - used_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
gc.collect()
torch.xpu.empty_cache()
return num_gpu_blocks, num_cpu_blocks
def _warm_up_model(self) -> None:
# IPEX don't support capture graph yet
pass
def init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
# use sockets as default Level zero IPC exchange backend. By
# default oneccl will use `drmfd` as mechanism which need extra
# dependency (libdrm and drm headers) on your system.
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
str(parallel_config.world_size))
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=self.local_rank,
backend="ccl")
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
if parallel_config.pipeline_parallel_size > 1:
# Add pp group init to avoid
# p2p communication as the first call
get_pp_group().all_reduce(torch.zeros(1).xpu())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment