Commit 3de379de authored by zhuwenwen's avatar zhuwenwen
Browse files

update unused code

parent 5ad884ee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import weakref
from typing import List, Optional, Set, Tuple, Dict
import torch
import torch.nn.functional as F
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import DelegateWorkerBase
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.distributed import broadcast_tensor_dict
from vllm.worker.worker_base import WorkerWrapperBase
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
# skip lora config in medusa
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: SpeculativeProposer
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def init_device(self):
self.worker.init_device()
def load_model(self):
super().load_model()
# get medusa choices and generate medusa_buffers
self.medusa_buffers = None
if self.tree_decoding and hasattr(self.model_runner.model, 'medusa_choices'):
self.medusa_choices = self.model_runner.model.medusa_choices
if self.medusa_choices is not None:
self.medusa_buffers = self.generate_medusa_buffers(
self.medusa_choices, device=self.device
)
if self.medusa_buffers is None:
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
else:
self._proposer = TreeStyleProposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
self.medusa_buffers,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
pass
def set_should_modify_greedy_probs_inplace(self):
pass
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Dict[str, torch.Tensor]:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
sample_indices_list = []
for seq_group in sampling_metadata.seq_groups:
sample_indices_list.append(seq_group.sample_indices)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
previous_logits = execute_model_req.previous_logits.logits if \
execute_model_req.previous_logits is not None else None
tensor_dict = {
"previous_hidden_states": previous_hidden_states,
"previous_logits": previous_logits,
"sample_indices_list": sample_indices_list,
"seq_lens": seq_lens
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
# Unused parameter.
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self._raise_if_unsupported(execute_model_req)
if self.is_driver_worker:
tensor_dict = self._get_driver_input_and_broadcast(execute_model_req)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of medusa worker!!!")
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=tensor_dict["previous_hidden_states"],
sample_indices_list=tensor_dict["sample_indices_list"],
previous_logits=tensor_dict["previous_logits"],
medusa_buffers=self.medusa_buffers)
# create tree attn masks
if self.is_driver_worker and self.medusa_buffers is not None:
seq_lens = tensor_dict["seq_lens"]
max_context_len = max(seq_lens)
for sampler_output, seq_len in zip(model_outputs, seq_lens):
context_len = seq_len
attn_masks = self.medusa_buffers['tree_attn_masks']
left_mask = torch.ones(attn_masks.shape[0], context_len,
dtype=attn_masks.dtype,
device=attn_masks.device)
attn_masks = torch.cat([left_mask, attn_masks], dim=-1)
right_pad = max_context_len - context_len
if right_pad > 0:
attn_masks = F.pad(attn_masks, (0, right_pad), "constant", 0)
sampler_output.tree_attn_masks = attn_masks
return model_outputs, False
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[List[int], List[int]]:
if not seq_group_metadata_list:
return [], []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
seq_lens.append(seq_len)
query_lens.append(seq_len - context_len)
else:
# first step of tree decoding need to ignore first token
if self.medusa_buffers is not None and seq_data.get_first_step_flag():
seq_data_len -= 1
seq_lens.append(seq_data_len)
query_lens.append(1)
return seq_lens, query_lens
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if execute_model_req is None:
return None
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MedusaWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MedusaWorker does not support beam search.")
def pad_path(self, path, length, pad_value=-2):
"""
Pad the given path list with a specific value up to a specified length.
Parameters:
- path (list): The original list that needs padding.
- length (int): The desired length of the padded list.
- pad_value (optional, default=-2): The value to use for padding.
Returns:
- list: A new list based on the original path but padded to the desired length.
Example:
>>> pad_path([1,2,3], 5)
[1, 2, 3, -2, -2]
Note:
If the given path is already longer than the specified length,
then no padding occurs, and the original path is returned.
"""
# Calculate the number of padding values needed by subtracting the length
# of the path from the desired length.
# Append the padding values to the original path and return the new list.
return path + [pad_value] * (length - len(path))
def generate_medusa_buffers(self, medusa_choices, device="cuda"):
"""
Generate buffers for the Medusa structure based on the provided choices.
Parameters:
- medusa_choices (list): A nested list representing tree in the Medusa structure.
- device (str): Device to which the tensors should be moved. Default is "cuda".
Returns:
- dict: A dictionary containing buffers related to the Medusa structure.
"""
# Sort the medusa_choices based on their lengths and then their values
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
medusa_len = len(sorted_medusa_choices) + 1
# Initialize depth_counts to keep track of how many choices have a particular depth
depth_counts = []
prev_depth = 0
for path in sorted_medusa_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
prev_depth = depth
# Create the attention mask for Medusa
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
medusa_attn_mask[:, 0] = 1
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
# retrieve ancestor position
if len(cur_medusa_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_medusa_choice) - 1):
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
start += depth_counts[i]
# Generate tree indices for the Medusa structure
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
medusa_tree_indices[0] = 0
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
start += depth_counts[i]
# Generate position IDs for the Medusa structure
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
start = 0
for i in range(len(depth_counts)):
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
start += depth_counts[i]
# Generate retrieval indices for Medusa structure verification
retrieve_indices_nest = []
retrieve_paths = []
for i in range(len(sorted_medusa_choices)):
cur_medusa_choice = sorted_medusa_choices[-i-1]
retrieve_indice = []
if cur_medusa_choice in retrieve_paths:
continue
else:
for c in range(len(cur_medusa_choice)):
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
retrieve_paths.append(cur_medusa_choice[:c+1])
retrieve_indices_nest.append(retrieve_indice)
max_length = max([len(x) for x in retrieve_indices_nest])
retrieve_indices = [self.pad_path(path, max_length) for path in retrieve_indices_nest]
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
retrieve_indices = retrieve_indices + 1
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
# Aggregate the generated buffers into a dictionary
medusa_buffers = {
"tree_attn_masks": medusa_attn_mask.int(),
"tree_indices": medusa_tree_indices,
"tree_position_ids": medusa_position_ids,
"retrieve_indices": retrieve_indices,
}
# Move the tensors in the dictionary to the specified device
medusa_buffers = {
k: v.clone().to(device)
if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device)
for k, v in medusa_buffers.items()
}
return medusa_buffers
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Callable, Optional, Union
import msgspec
import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
class SpecDecodeWorkerMetrics(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""
# The empirical acceptance rate of the proposal method on a per-token basis.
# This is useful for evaluating how well the proposal method aligns with the
# scoring method.
draft_acceptance_rate: float
# The empirical efficiency, measured as the number of tokens emitted by the
# system divided by the number of tokens that could be emitted by the system
# if the proposal method were perfect.
system_efficiency: float
# The number of speculative tokens produced by the proposal method.
draft_tokens: int
# The number of tokens emitted by the entire system.
emitted_tokens: int
# The number of tokens accepted by the scoring model and verification
# routine, e.g. Llama2-70B and lossless rejection sampling.
#
# NOTE: Any token accepted by the verification routine is considered
# accepted (regardless of if the speculative prefix is also accepted). The
# user will usually see less accepted tokens. This metric is helpful when
# evaluating alignment of the proposal method with the scoring model.
accepted_tokens: int
# The number of speculative tokens per sequence.
num_spec_tokens: int
Timer = Callable[[], float]
class AsyncMetricsCollector:
"""Class which copies rejection/typical-acceptance sampler metrics
from the device to CPU on a non-default Torch stream.
"""
def __init__(self,
spec_decode_sampler: SpecDecodeBaseSampler,
timer: Optional[Timer] = None,
collect_interval_s: float = 5.0):
self.spec_decode_sampler = spec_decode_sampler
self._timer = time.time if timer is None else timer
self._rank: Optional[int] = None
# We don't have a device set yet.
self._copy_stream: Optional[torch.cuda.Stream] = None
self._in_flight_copy: Optional[torch.cuda.Event] = None
self._aggregate_num_draft_tokens = 0
self._rejsample_metrics_collect_interval_s = collect_interval_s
self._last_metrics_collect_time = self._timer()
def init_gpu_tensors(self, rank: int) -> None:
self._rank = rank
self._copy_stream = torch.cuda.Stream()
def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
torch.cuda.set_device(device_type)
device_type = device_type.type
# stream = current_platform.Stream
# if stream is not None:
# self._copy_stream = stream()
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
pin_memory = is_pin_memory_available()
self._aggregate_num_accepted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
self._aggregate_num_emitted_tokens = torch.tensor(
0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# Skip for any platform that doesn't have device Event
# if current_platform.Event is None:
# return None
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
ready_event = self._in_flight_copy
self._in_flight_copy = None
return self._collect_rejsample_metrics(k, ready_event)
# Otherwise, check if we should start a new copy.
if self._should_collect_rejsample_metrics(self._timer()):
assert self._in_flight_copy is None
self._in_flight_copy = self._copy_rejsample_metrics_async()
return None
def _should_collect_rejsample_metrics(self, now: float) -> bool:
"""Return whether or not this iteration should print sampling
metrics.
"""
if self._rank != 0:
return False
return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns a device event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(current_platform.current_stream())
with current_platform.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = current_platform.Event()
aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready
def _collect_rejsample_metrics(
self, k: int,
ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
"""Create metrics object from statistics copied asynchronously.
Args:
k: int. The number of speculative tokens; used to determine system
efficiency.
ready_event: torch.cuda.Event. The CUDA event recording when the
async GPU->CPU copy is complete.
"""
ready_event.synchronize()
# update time of last collection
self._last_metrics_collect_time = self._timer()
accepted_tokens = self._aggregate_num_accepted_tokens.item()
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
max_num_emitted_tokens = self.get_max_num_emitted_tokens(
draft_tokens, k)
if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens
else:
draft_acceptance_rate = float("nan")
if max_num_emitted_tokens > 0:
system_efficiency = emitted_tokens / max_num_emitted_tokens
else:
system_efficiency = float("nan")
return SpecDecodeWorkerMetrics(
num_spec_tokens=k,
draft_acceptance_rate=draft_acceptance_rate,
system_efficiency=system_efficiency,
accepted_tokens=accepted_tokens,
draft_tokens=draft_tokens,
emitted_tokens=emitted_tokens,
)
@staticmethod
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
"""Calculate the number of emitted tokens, assuming all tokens are
accepted.
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert draft_tokens % k == 0
total_num_spec_seqs = draft_tokens // k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted = k + 1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Set, Tuple, Dict
import torch
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.distributed import broadcast_tensor_dict
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
"""Worker for MLPSpeculator models.
Not currently compatible with LoRA or chunked prefill.
"""
def _get_driver_input_and_broadcast(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
index: int,
last_tokens: Optional[torch.Tensor]=None,
previous_hidden_states: Optional[torch.Tensor]=None,
sampling_metadata: Optional[SamplingMetadata]=None
) -> Dict[str, torch.Tensor]:
if sampling_metadata is None and execute_model_req is not None:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
# b x 1
last_tokens = input_tokens.unsqueeze(1)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory, generators)
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)
tensor_dict = {
"input_tokens": last_tokens,
"previous_hidden_states": previous_hidden_states,
"sample_len": sample_len,
"head_index": index
}
if self.do_metadata_broadcast:
broadcast_tensor_dict(tensor_dict, src=0)
return tensor_dict, sampling_metadata
def _get_worker_input_from_broadcast(
self
) -> Optional[Dict[str, torch.Tensor]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
return broadcast_data
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
# therefore does not need this parameter.
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For mlp spec worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
model_outputs = []
last_tokens = None
previous_hidden_states = None
sampling_metadata = None
for index in range(sample_len):
if self.is_driver_worker:
tensor_dict, sampling_metadata = self._get_driver_input_and_broadcast(execute_model_req,
sample_len,
index,
last_tokens,
previous_hidden_states,
sampling_metadata)
assert sampling_metadata is not None
output, previous_hidden_states = self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=sampling_metadata,
head_index=index)
last_tokens = output.sampled_token_ids
model_outputs.append(output)
else:
tensor_dict = self._get_worker_input_from_broadcast()
if tensor_dict is None:
raise ValueError("Can not get inputs of mlp_speculator worker!!!")
self.model_runner.model.generate_proposals(
input_ids=tensor_dict["input_tokens"],
previous_hidden_states=tensor_dict["previous_hidden_states"],
num_predict_tokens=tensor_dict["sample_len"],
sampling_metadata=None,
head_index=tensor_dict["head_index"])
if self.is_driver_worker:
assert len(model_outputs) == sample_len
return model_outputs, True
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, List[int], List[int]]:
if not seq_group_metadata_list:
return torch.empty(0, device=self.device), [], []
input_tokens: List[int] = []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
input_tokens.extend(tokens)
query_lens.append(seq_len - context_len)
else:
seq_lens.append(seq_data_len)
input_tokens.append(seq_data.get_last_token_id())
query_lens.append(1)
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
return input_tokens_tensor, seq_lens, query_lens
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import DelegateWorkerBase
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less.
The MultiStepWorker does not support cache swap operations, or beam search.
Cache swap operations do not require large modifications. On the other hand,
beam search requires memory allocations during sequence forks and thus
requires more thought for MultiStepWorker support.
"""
def __init__(self, *args, **kwargs):
DelegateWorkerBase.__init__(self, *args, **kwargs)
# Lazy initialization list.
self._proposer: SpeculativeProposer
def init_device(self) -> None:
self.worker.init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.sampler.include_gpu_probs_tensor = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.include_gpu_probs_tensor) = True
def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.sampler.should_modify_greedy_probs_inplace = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.should_modify_greedy_probs_inplace
) = True
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
if expanded_request.previous_hidden_states is not None:
self.worker.model_runner.return_hidden_states = True
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._maybe_update_previous_hidden_states(
model_output, expanded_request)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
@staticmethod
def _maybe_update_previous_hidden_states(
model_output: SamplerOutput,
expanded_request: ExecuteModelRequest) -> None:
"""
Updates the previous hidden states in an expanded request
in-place with the hidden states from the model output.
"""
if expanded_request.previous_hidden_states is not None:
expanded_request.previous_hidden_states = HiddenStates(
model_output.hidden_states,
expanded_request.seq_group_metadata_list)
@staticmethod
def _expand_execute_model_request(
execute_model_req: ExecuteModelRequest,
seq_with_bonus_token_in_last_step: set,
) -> Tuple[ExecuteModelRequest, List[int]]:
"""
Expands the execute model request based on sequences with bonus
tokens.
For each sequence with a bonus token, this method creates a new
sequence without the bonus token and adds it to the execute model
request. The original sequence groups are also retained. The indices
of the original sequence groups are returned for further processing.
Args:
execute_model_req (ExecuteModelRequest): The original execute
model request.
seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
contain bonus tokens.
Returns:
Tuple[ExecuteModelRequest, List[int]]: The updated execute model
request with expanded sequences and a list of indices corresponding
to the original sequence groups.
"""
updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
updated_execute_model_req = execute_model_req.clone(
updated_seq_group_metadata_list)
indices_of_original_sequence_groups = []
for seq_group in execute_model_req.seq_group_metadata_list:
seq_group_has_bonus_tokens = False
for seq_id, _ in seq_group.seq_data.items():
# Identify sequences with bonus tokens in the sequence group.
if seq_id in seq_with_bonus_token_in_last_step:
seq_group_has_bonus_tokens = True
break
if seq_group_has_bonus_tokens:
#Create new sequences without the last bonus token. These new
# sequence have the same sequence id as the original sequence.
# We create a new sequence group and add them there.
updated_seq_group_without_bonus_token = \
MultiStepWorker._copy_seq_metadata_excluding_last_token(
seq_group, seq_with_bonus_token_in_last_step)
updated_seq_group_metadata_list.append(
updated_seq_group_without_bonus_token)
# Add the original sequence group.
updated_seq_group_metadata_list.append(
MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
# Record the index of the original sequence group.
indices_of_original_sequence_groups.append(
len(updated_seq_group_metadata_list) - 1)
updated_execute_model_req.seq_group_metadata_list =\
updated_seq_group_metadata_list
if isinstance(updated_execute_model_req.previous_hidden_states,
HiddenStates):
updated_execute_model_req.previous_hidden_states\
.expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
return updated_execute_model_req, indices_of_original_sequence_groups
@staticmethod
def _filter_model_output(
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[output_indices_to_retain]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[output_indices_to_retain]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[output_indices_to_retain]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: set,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
@staticmethod
def _append_new_tokens(
model_output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
indices_of_seq_with_bonus_tokens: List[int]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
count = 0
for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
zip(seq_group_metadata_list, model_output)):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
# Determine the actual token ID to be generated,
# considering bonus tokens
if index != indices_of_seq_with_bonus_tokens[count]:
bonus_seq_metadata = seq_group_metadata_list[
indices_of_seq_with_bonus_tokens[count]]
_, bonus_token_seq_data = next(
iter(bonus_seq_metadata.seq_data.items()))
token_id = bonus_token_seq_data.output_token_ids[-1]
else:
count += 1
seq.append_token_id(token_id, token_logprob.logprob,
seq_output.output_embed)
seq.update_num_computed_tokens(1)
@staticmethod
def _shallow_copy_seq_group_metadata(
seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
Helpful when the vLLM scheduler runs in the same process as the worker.
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
# Shallow-copy the SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
# We must shallow-copy seq_group_metadata as is_prompt could change.
new_seq_group_metadata = copy.copy(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[seq_id].output_token_ids =\
old_seq_data.output_token_ids[:]
new_seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata
@staticmethod
def _copy_seq_metadata_excluding_last_token(
seq_group_metadata: SequenceGroupMetadata,
seq_ids_to_copy: Set[int],
) -> SequenceGroupMetadata:
"""
Creates a shallow copy of the given SequenceGroupMetadata, retaining
only the sequence IDs specified in seq_ids_to_copy. For each of these
sequence IDs, all output_token_ids except the last one are copied.
Sequence IDs not in seq_ids_to_copy are excluded from the copy.
Parameters:
seq_group_metadata (SequenceGroupMetadata): The original sequence
group metadata.
seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
copy.
Returns:
SequenceGroupMetadata: A shallow copy of the sequence group metadata
with the specified modifications.
"""
# Shallow-copy the SequenceGroupMetadata.
new_seq_group_metadata = copy.copy(seq_group_metadata)
# Shallow-copy seq_data and modify the output_token_ids.
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
if (seq_id in seq_ids_to_copy):
new_seq_data[seq_id] = copy.copy(old_seq_data)
# Copy all the output token ids except the last.
# Also reduce num_computed_tokens by 1 since we are not
# including the last output token.
# NOTE: num_computed_tokens is not directly used by the
# speculative decoding workers, as it is only relevant for
# chunked prefill, which is disabled for speculative decoding.
# However, to maintain consistency in num_computed_tokens,
# we update it here.
new_seq_data[seq_id].output_token_ids =\
old_seq_data.output_token_ids[:-1]
new_seq_data[seq_id].update_num_computed_tokens(-1)
new_seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
num_steps: int) -> None:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert self.model_runner.block_size is not None
for seq_group_metadata in seq_group_metadata_list:
# Only one seq_id is guaranteed because there is no beam search.
seq_id = list(seq_group_metadata.seq_data.keys())[0]
seq = seq_group_metadata.seq_data[seq_id]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len = seq.get_len() + num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots = final_seq_len - 1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots of block.
number_physical_blocks = len(
seq_group_metadata.block_tables[seq_id])
allocated_kv_slots = (number_physical_blocks *
self.model_runner.block_size)
if required_num_kv_slots > allocated_kv_slots:
request_id = seq_group_metadata.request_id
raise ValueError(
"The worker attempted to run "
f"{num_steps} times but found insufficient KV space for "
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
f"{required_num_kv_slots=}).")
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if execute_model_req is None:
return None
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
def maybe_load_lm_head_weight(
self,
lm_head_weight: torch.Tensor,
) -> None:
weight_loader = getattr(
self.worker.model_runner.model_runner.model.lm_head.weight,
"weight_loader", default_weight_loader)
weight_loader(
self.worker.model_runner.model_runner.model.lm_head.weight,
lm_head_weight)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention( ...@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
for chunk_idx in range(cdiv(C, MCC)): for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C) chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start_table Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
......
...@@ -45,10 +45,8 @@ class CommonAttentionMetadata: ...@@ -45,10 +45,8 @@ class CommonAttentionMetadata:
seq_lens_cpu: torch.Tensor seq_lens_cpu: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens """(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens""" and newly scheduled tokens"""
num_computed_tokens_cpu: torch.Tensor num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request""" """(batch_size,), the number of computed tokens for each request"""
num_reqs: int num_reqs: int
"""Number of requests""" """Number of requests"""
num_actual_tokens: int num_actual_tokens: int
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment