Commit 3de379de authored by zhuwenwen's avatar zhuwenwen
Browse files

update unused code

parent 5ad884ee
......@@ -78,7 +78,6 @@ async def serve_http(app: FastAPI,
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
watchdog_task.cancel()
......
......@@ -81,7 +81,6 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_TREE_DECODING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
......@@ -154,6 +153,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_TREE_DECODING: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
......
......@@ -345,7 +345,6 @@ class LoRAModelManager(AdapterModelManager):
max_loras=self.lora_config.max_loras)
super().__init__(model)
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in"
f" {self.model.__class__.__name__}."
......
from typing import Optional, Union
import torch
import triton
import triton.language as tl
from vllm.utils import is_hip
def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Fill a tensor with uniform random numbers, seeded independently per row.

    Works like ``torch.rand`` except that every row gets its own seed.

    Args:
        *size: Shape of the result; at most three dimensions are accepted.
        seeds: 1D tensor holding one seed per output row.
        out: Optional preallocated destination. Its shape must equal ``size``.
        dtype: dtype used when the destination has to be allocated.
        device: device used when the destination has to be allocated.
        pin_memory: forwarded to ``torch.empty`` when allocating.

    Returns:
        The destination tensor (``out`` when it was supplied).

    Raises:
        ValueError: if more than 3 dimensions are requested, ``out`` has a
            different shape than ``size``, ``seeds`` is not 1D, or ``seeds``
            does not contain exactly one element per row.

    For a 3D output the kernel derives the extra seeds deterministically by
    XOR-ing the row seed with the index along the middle dimension:
        [
            row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
        ]
    """
    rank = len(size)
    if rank > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is not None:
        if out.shape != size:
            raise ValueError("shape of out and size must be the same")
    else:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)

    # Normalise every supported rank to (rows, middle dim, columns).
    if rank == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row, stride_3d = out.stride(0), out.stride(1)
    elif rank == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row, stride_3d = out.stride(0), 1
    else:
        n_rows, n_3d, n_cols = 1, 1, out.shape[0]
        stride_row, stride_3d = 1, 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")
    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Triton's philox generator yields four numbers per call, so each program
    # works on a quarter of the padded row width and writes up to four
    # slices of the row.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size

    # Hand-tuned warp counts (reported to perform best on A100 for simple
    # kernels like this one); HIP uses smaller counts for the larger blocks.
    if philox_block_size >= 8192:
        num_warps = 16 if is_hip() else 32
    elif philox_block_size >= 4096:
        num_warps = 8 if is_hip() else 16
    elif philox_block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    launch_grid = (n_rows, n_3d)
    _seeded_uniform_triton[launch_grid](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out
@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Write random float32 values in [0, 1) into one row of the output
    tensor, using that row's seed.

    One program handles one (row, middle-dimension index) pair.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use (1 to 4).
        block_size: Number of elements written per philox output.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Which row / middle-dimension slice this program is responsible for.
    pid_row = tl.program_id(axis=0)
    pid_3d = tl.program_id(axis=1)

    lane_offsets = tl.arange(0, block_size)

    # Start from the row seed; slices past the first get a derived seed.
    row_seed = tl.load(seed_ptr + pid_row * seed_row_stride)
    if pid_3d > 0:
        row_seed ^= pid_3d

    # A single philox call produces four independent blocks of randoms.
    rand_a, rand_b, rand_c, rand_d = tl.rand4x(row_seed, lane_offsets)

    dst_ptr = (out_ptr + pid_row * out_row_stride +
               pid_3d * out_3d_stride)

    # First slice is always written; the rest only when requested. Every
    # store is masked so nothing is written past the last real column.
    tl.store(dst_ptr + lane_offsets, rand_a, mask=lane_offsets < n_cols)
    if n_slices > 1:
        offsets_b = tl.arange(block_size, block_size * 2)
        tl.store(dst_ptr + offsets_b, rand_b, mask=offsets_b < n_cols)
    if n_slices > 2:
        offsets_c = tl.arange(block_size * 2, block_size * 3)
        tl.store(dst_ptr + offsets_c, rand_c, mask=offsets_c < n_cols)
    if n_slices > 3:
        offsets_d = tl.arange(block_size * 3, block_size * 4)
        tl.store(dst_ptr + offsets_d, rand_d, mask=offsets_d < n_cols)
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import is_hip
# Lower bound applied to uniform noise in _uniform_to_exponential so that
# log(0) is never evaluated.
_EPS: tl.constexpr = 1e-6
def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise).

    Each column-wise split of ``probs`` is sampled on its own; for every
    output position the split whose winning (noise-modified) probability
    is largest provides the final token.

    Args:
        probs: Probabilities to sample from; split along dim 1.
        seeds: 2D tensor whose first dimension equals ``n_splits``;
            ``seeds[i]`` is used for split ``i``.
        n_splits: Number of column-wise splits.
        sampled_tokens_size: Shape of the sampled-token output.
        sampled_logprobs_size: Shape of the sampled-logprob output.
        sample_indices: Rows of ``probs`` to sample from.
        logprobs: Log-probabilities; split the same way as ``probs``.
        modify_greedy_probs: If True, ``probs`` is overwritten in place
            with zeros everywhere except 1.0 at the sampled tokens.
        save_logprobs: If True, the logprobs of the sampled tokens are
            gathered and returned.

    Returns:
        Tuple ``(sampled_tokens, sampled_logprobs, sampled_modified_probs)``
        where ``sampled_logprobs`` is None unless ``save_logprobs`` is set.
    """
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]

    # These do not depend on the split being processed.
    n_samples = sample_indices.shape[0]
    n_best = sampled_tokens_size[1]

    # Number of vocabulary columns covered by the splits already processed.
    # tensor_split may return splits of different widths (the leading ones
    # are one column wider when the vocab is not divisible by n_splits), so
    # the offset has to be accumulated rather than derived from one width.
    vocab_offset = 0
    for i in range(n_splits):
        n_cols = split_probs[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            # Greedy-prob modification is applied once, after the reduction.
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if vocab_offset > 0:
            # Convert split-local token ids into full-vocabulary ids.
            sampled_tokens_tmp[i].add_(vocab_offset)
        vocab_offset += n_cols

    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Draw tokens from ``probs`` using per-sequence seeds.

    A subset of the sequences can be selected through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
        max_best_of: Number of samples to generate per sequence.
            Sequence seed will be incremented by 1 each time.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if save_modified_probs else None

    Raises:
        ValueError: if save_logprobs is True but no logprobs were given.
    """
    # Default to sampling every row.
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    # Placeholder shape for outputs the kernel is told not to write.
    unused_size = (0, 0)

    if save_logprobs and logprobs is None:
        raise ValueError(
            "logprobs tensor must be provided if save_logprobs is True")
    if save_logprobs:
        sampled_logprobs_size = sampled_tokens_size
    else:
        sampled_logprobs_size = unused_size
        # The kernel still needs a tensor argument, so probs is passed.
        logprobs = probs
    assert logprobs is not None

    sampled_modified_probs_size = (sampled_tokens_size
                                   if _save_modified_probs else unused_size)

    # When the vocabulary has too many columns for one Triton launch, the
    # tensor is split, each part is sampled separately, and the per-split
    # results are merged with an argmax+gather.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        split_result = _multi_split_sample(
            probs,
            seeds,
            n_splits,
            sampled_tokens_size,
            sampled_logprobs_size,
            sample_indices,
            logprobs=logprobs,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs)
        sampled_tokens, sampled_logprobs, sampled_modified_probs = split_result
    else:
        target = probs.device
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=target)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=target)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=target)
        uniform_noise = seeded_uniform(sample_indices.shape[0],
                                       max_best_of,
                                       probs.shape[1],
                                       seeds=seeds.flatten(),
                                       device=target,
                                       dtype=probs.dtype)
        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )

    # Outputs that were not requested are reported as None.
    if not save_logprobs:
        sampled_logprobs = None
    if not _save_modified_probs:
        sampled_modified_probs = None
    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sample_indices: torch.Tensor,
    output_samples: torch.Tensor,
    output_logprobs: torch.Tensor,
    output_modified_probs: torch.Tensor,
    seeds: torch.Tensor,
    uniform_noise: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = True,
    save_modified_probs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample tokens from probs by launching the Triton sampling kernel.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [n, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel). The kernel
            indexes it by sample position, not by row of probs.
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).

    Returns:
        The three output tensors that were passed in, in the order
        ``(output_samples, output_logprobs, output_modified_probs)``.
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two that is not smaller than
    # the number of columns in probs, so a whole row fits in one block.
    block_size = triton.next_power_of_2(n_cols)

    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    num_warps = 4
    if block_size >= 8192:
        num_warps = 16 if is_hip() else 32
    elif block_size >= 4096:
        num_warps = 8 if is_hip() else 16
    elif block_size >= 2048:
        num_warps = 8

    # Enqueue kernel. The launch grid is 2D: one kernel instance per
    # (sampled row, best-of index) pair.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs
@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Turn uniform samples into exponential samples (inversion method)."""
    # The incoming samples lie in [0, 1); raising the lower bound to _EPS
    # keeps log() away from 0 and so avoids a later division by zero.
    floor = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    clamped = tl.maximum(uniform_noise, floor)
    return -tl.log(clamped)
@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # One program per (sample, best-of) pair; the pairs are independent.
    pid_sample = tl.program_id(0)
    pid_best = tl.program_id(1)

    # Look up which row of probs this sample refers to, and its seed.
    # A seed of 0 selects greedy sampling.
    src_row = tl.load(sample_indices_ptr + pid_sample)
    seed = tl.load(seeds_ptr + pid_sample)
    uses_random_sampling = seed != 0

    # Pointer to the first element of the selected probs row.
    probs_row_ptr = probs_ptr + src_row * probs_row_stride

    # block_size is a power of two covering n_cols, so one block holds the
    # whole row; lanes past n_cols are masked and read as -inf.
    lanes = tl.arange(0, block_size)
    in_bounds = lanes < n_cols
    prob_row = tl.load(probs_row_ptr + lanes,
                       mask=in_bounds,
                       other=float("-inf"))

    if uses_random_sampling:
        noise_row_ptr = (uniform_noise_ptr +
                         pid_sample * uniform_noise_row_stride +
                         pid_best * uniform_noise_best_stride)
        noise = tl.load(noise_row_ptr + lanes, mask=in_bounds, other=0.5)
        exp_noise = _uniform_to_exponential(noise)
        prob_row /= exp_noise

    sampled_value, sampled_token = tl.max(prob_row,
                                          axis=0,
                                          return_indices=True)
    # Defensive clamp: the winning index is not expected to fall outside
    # the real columns, but it is limited to n_cols - 1 just in case.
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1

    # Store the chosen token id.
    out_elem_ptr = (output_ptr + pid_sample * output_row_stride + pid_best)
    tl.store(out_elem_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # For greedy rows, overwrite the probs row with a one-hot
            # distribution on the sampled token. Speculative decoding
            # requires the sampling method to be encoded in the
            # probability distribution itself.
            prob_row = tl.where(lanes == sampled_token, 1.0, 0.0)
            tl.store(probs_row_ptr + lanes, prob_row, mask=in_bounds)

    if save_modified_probs:
        out_elem_ptr = (output_modified_probs_ptr +
                        pid_sample * output_row_stride + pid_best)
        tl.store(out_elem_ptr, sampled_value)

    if save_logprobs:
        # Fetch only the logprob of the sampled token; the logprobs tensor
        # is addressed with the same row stride as probs.
        sampled_logprob = tl.load(logprobs_ptr + src_row * probs_row_stride +
                                  sampled_token)
        out_elem_ptr = (output_logprobs_ptr +
                        pid_sample * output_row_stride + pid_best)
        tl.store(out_elem_ptr, sampled_logprob)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional, List
import torch
import torch.jit
import torch.nn.functional as F
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeDeterministicBaseSampler)
from vllm.logger import init_logger
logger = init_logger(__name__)
class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
    """Apply typical acceptance sampling as described in section 3.3.1 in
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/pdf/2401.10774
    """

    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
            posterior_threshold : A threshold value that sets a lower bound
                on the posterior probability of a token in target model for it
                to be accepted.
            posterior_alpha : A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
        """
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha
        super().__init__(strict_mode=strict_mode)
        # Tree-style decoding is enabled only when the environment variable
        # VLLM_TREE_DECODING is exactly '1'; it is read once, here.
        self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')

    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        cart_candidates: Optional[torch.Tensor] = None,
        best_candidates: Optional[List] = None,
        accept_lengths: Optional[List] = None,
        first_step_flags: Optional[List] = None,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model using the probability
        of each token according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token will be
        accepted.

        Args:
            target_with_bonus_probs: The probability distribution over token
                ids given context according to the target model. Without
                tree decoding the last position along dim 1 is dropped
                before evaluation; with tree decoding it is passed on as is.
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: This parameter is unused by the acceptance sampler.
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]
            cart_candidates: tree-style cartesian candidates (required when
                tree decoding is enabled).
            best_candidates: output list; the index of the best candidate of
                every sequence is appended to it (tree decoding only).
            accept_lengths: output list; the accepted length of every
                sequence is appended to it (tree decoding only).
            first_step_flags: per-sequence flags telling whether this is the
                first decoding step (tree decoding only).

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids)

        if not self.tree_decoding:
            target_probs = target_with_bonus_probs[:, :-1]
            accepted = self._evaluate_accepted_tokens(target_probs,
                                                      draft_token_ids)
            recovered_token_ids = self._get_recovered_token_ids(target_probs)
            output_token_ids = self._create_output(accepted,
                                                   recovered_token_ids,
                                                   draft_token_ids,
                                                   bonus_token_ids)
        else:
            assert cart_candidates is not None
            target_probs = target_with_bonus_probs
            output_token_ids = self._evaluate_accepted_tokens_tree_style(
                target_probs, draft_token_ids, cart_candidates,
                best_candidates, accept_lengths, first_step_flags)
        return output_token_ids

    def _evaluate_accepted_tokens_tree_style(self, target_probs,
                                             draft_token_ids, cart_candidates,
                                             output_best_candidates,
                                             accept_lengths,
                                             first_step_flags):
        r"""
        Pick, for every sequence, the tree candidate with the longest
        accepted prefix and return its accepted token ids.

        Parameters:
        ----------
        target_probs : torch.Tensor
            Target-model probabilities for every position of every
            candidate; the last position along dim 2 is dropped here.
            The trailing dimension is the vocabulary.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids. Only its last dimension ``k`` is used, as the
            width of the returned tensor.
        cart_candidates : torch.Tensor
            A tensor of shape (batch_size, retrieve_size, tree_depth)
            representing the cart candidates of tree proposals.
        output_best_candidates : list
            Output; the best candidate index of every sequence (a 0-dim
            tensor) is appended.
        accept_lengths : list
            Output; the accepted length of every sequence (a Python int)
            is appended.
        first_step_flags : sequence
            Per-sequence truthy/falsy flags. When falsy, the first entry
            of the chosen candidate is kept in the output; when truthy it
            is skipped.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        .. math::
            p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
            \min \left( \epsilon, \delta * \exp \left(
                -H(p_{\text{original}}(
                    \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

        where :math:`p_{\text{original}}` corresponds to target_probs
        and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        The accepted length of a candidate is the number of leading
        positions that all satisfy the condition. Among the candidates
        reaching the maximum accepted length of a sequence, the one with
        the highest summed log-probability over the accepted prefix wins.

        Returns:
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) holding the selected token
            ids of every sequence, right-padded with -1.
        """
        target_probs = target_probs[:, :, :-1]
        device = target_probs.device
        batch_size = cart_candidates.shape[0]

        # Probability the target model gives to each candidate token.
        # [batch_size, retrieve_size, max_depth]
        candidates_prob = torch.gather(
            target_probs,
            dim=-1,
            index=cart_candidates[:, :, 1:].unsqueeze(-1)).squeeze(-1)

        # Entropy of the target distribution at every position; 1e-5 keeps
        # the logarithm finite. [batch_size, retrieve_size, max_depth]
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + 1e-5), dim=-1)
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy) * self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        posterior_mask = candidates_prob > threshold

        # The cumulative product zeroes everything after the first
        # rejection, so the sum is the accepted prefix length.
        # [batch_size, retrieve_size]
        candidates_accept_length = torch.cumprod(posterior_mask,
                                                 dim=2).sum(dim=-1)
        accept_length, _ = candidates_accept_length.max(dim=-1)  # [batch_size]

        if torch.any(accept_length > 0):
            # Keep only candidates that reach the per-sequence maximum ...
            is_longest = (candidates_accept_length ==
                          accept_length.unsqueeze(-1)).unsqueeze(-1)
            candidates_prob = candidates_prob * is_longest
            # ... and only the positions inside the accepted prefix.
            # Broadcasting yields a [batch_size, 1, max_depth] mask.
            depth_positions = torch.arange(candidates_prob.shape[-1],
                                           device=device)
            in_prefix = depth_positions < accept_length[:, None, None]
            candidates_prob = candidates_prob * in_prefix
            # add 1e-3 to avoid zero value
            # (a sum of logs is used in place of a product)
            likelihood = torch.sum(torch.log(candidates_prob + 1e-3), dim=-1)
            best_candidate = torch.argmax(likelihood, dim=-1)  # [batch_size]
        else:
            # Nothing was accepted anywhere: fall back to candidate 0.
            best_candidate = torch.zeros((batch_size),
                                         dtype=torch.long,
                                         device=device)

        k = draft_token_ids.shape[-1]
        # Copy the per-sequence results to the host once, so the loop below
        # works with plain Python ints instead of indexing device tensors
        # on every iteration.
        accept_length_list = accept_length.cpu().tolist()
        best_candidate_list = best_candidate.cpu().tolist()

        output_token_id_list = []
        for i in range(batch_size):
            n_accepted = accept_length_list[i]
            best_idx = best_candidate_list[i]
            output_best_candidates.append(best_candidate[i])
            accept_lengths.append(n_accepted)
            if not first_step_flags[i]:
                # Keep the leading entry of the candidate as well.
                select_indices = cart_candidates[i, best_idx, :n_accepted + 1]
                pad_width = k - 1 - n_accepted
            else:
                select_indices = cart_candidates[i, best_idx,
                                                 1:n_accepted + 1]
                pad_width = k - n_accepted
            # Right-pad with -1 so every row has width k.
            select_indices = F.pad(select_indices, (0, pad_width), 'constant',
                                   -1)
            output_token_id_list.append(select_indices)
        return torch.stack(output_token_id_list, dim=0)

    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""
        Evaluates and returns a mask of accepted tokens based on the
        posterior probabilities.

        Args:
            target_probs (torch.Tensor): A tensor of shape
                (batch_size, k, vocab_size) representing the probabilities of
                each token in the vocabulary for each position in the proposed
                sequence. This is the distribution generated by the target
                model.
            draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
                representing the proposed token ids.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        $$
        p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
        \min \left( \epsilon, \delta * \exp \left(
            -H(p_{\text{original}}(
                \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
        $$

        where $p_{\text{original}}$ corresponds to target_probs
        and $\epsilon$ and $\delta$ correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        This method computes the posterior probabilities for the given
        draft token ids based on the provided target probabilities. It
        calculates the entropy of the posterior distribution and determines
        a dynamic threshold for each token position using the provided
        posterior_threshold and posterior_alpha values. The method then
        returns a boolean mask indicating which tokens can be accepted.

        Returns:
            torch.Tensor: A boolean tensor of shape (batch_size, k) where each
                element indicates whether the corresponding draft token has
                been accepted or rejected. True indicates acceptance and false
                indicates rejection.
        """
        device = target_probs.device
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values.
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy, device=device) *
            self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        accepted_mask = candidates_prob > threshold
        return accepted_mask

    def _get_recovered_token_ids(self, target_probs):
        """
        The recovered token ids will fill the first unmatched token
        by the target token.

        Args:
            target_probs (torch.Tensor): A tensor of shape
                (batch_size, k, vocab_size) containing the target probability
                distribution.

        Returns:
            torch.Tensor: A tensor of shape (batch_size, k) with the recovered
                token ids which are selected from target probs.
        """
        max_indices = torch.argmax(target_probs, dim=-1)
        return max_indices
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import collections
import copy
import dataclasses
import fnmatch
import glob
import inspect
import itertools
import math
import os
import time
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, cast)
import gguf
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.attention import Attention
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
VllmConfig, set_current_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (ParamMapping,
configure_quant_config,
get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
np_cache_weights_iterator, pt_weights_iterator,
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily place a module's CPU parameters on ``target_device``.

    Inside the ``with`` body every parameter that was on the CPU lives on
    ``target_device``. On exit those parameters are copied back to the CPU.
    Parameters that were not on the CPU, and parameters added inside the
    body, are left alone. For a CPU target nothing is moved.

    Args:
        module: Module whose parameters may have to be moved.
        target_device: Device the parameters should be on inside the body.

    Yields:
        The same ``module`` object.
    """
    if target_device.type == "cpu":
        # Nothing to relocate when the target itself is the CPU.
        yield module
        return

    # Remember, by name, where each relocated parameter came from.
    moved_from: Dict[str, torch.device] = {}
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Not on the CPU: leave it where it is.
            continue
        moved_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        pin_memory = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            if param_name not in moved_from:
                # Either created inside the body or never moved.
                continue
            source_device: torch.device = moved_from[param_name]
            if source_device.type != "cpu":
                param.data = param.data.to(source_device)
                continue
            # `torch.empty_like` does not support `pin_memory` argument,
            # so the CPU buffer is allocated explicitly.
            weights = param.data
            cpu_data = torch.empty_strided(
                size=weights.size(),
                stride=weights.stride(),
                dtype=weights.dtype,
                layout=weights.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(weights)
            param.data = cpu_data
# Module-level logger used by the loaders and helpers defined below.
logger = init_logger(__name__)
def _initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: Optional[type[nn.Module]] = None,
) -> nn.Module:
    """Instantiate the model class described by ``vllm_config``.

    Args:
        vllm_config: Aggregate configuration; its ``model_config`` selects
            the architecture when ``model_class`` is not given.
        prefix: Name prefix forwarded to the model constructor.
        model_class: Explicit class to instantiate instead of resolving one
            through ``get_model_architecture``.

    Returns:
        The constructed model. Classes whose ``__init__`` accepts both
        ``vllm_config`` and ``prefix`` are called with exactly those two
        keywords; any other class is treated as old-style, triggers a
        ``DeprecationWarning``, and receives only the legacy keyword
        arguments that its ``__init__`` declares.
    """
    model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    init_signature = inspect.signature(model_class.__init__)
    accepted = [p.name for p in init_signature.parameters.values()]

    if "vllm_config" in accepted and "prefix" in accepted:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )

    # Try to be compatible with old-style model classes: each legacy keyword
    # is only resolved (lazily) when the constructor actually declares it.
    legacy_arguments = (
        ("prefix", lambda: prefix),
        ("config", lambda: model_config.hf_config),
        ("cache_config", lambda: vllm_config.cache_config),
        ("quant_config", lambda: vllm_config.quant_config),
        ("lora_config", lambda: vllm_config.lora_config),
        ("scheduler_config", lambda: vllm_config.scheduler_config),
        ("parallel_config", lambda: vllm_config.parallel_config),
    )
    kwargs = {}
    for keyword, resolve in legacy_arguments:
        if keyword in accepted:
            kwargs[keyword] = resolve()

    with set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(**kwargs)
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                   target_device: torch.device) -> None:
    """Run the post-load weight hooks of every submodule of ``model``.

    The first pass handles ``QKVCrossParallelLinear`` modules and modules
    carrying a ``QuantizeMethodBase`` ``quant_method``. The second pass
    handles ``Attention`` modules that expose
    ``process_weights_after_loading``.

    Args:
        model: Model whose weights have just been loaded.
        model_config: Supplies the dtype passed to the attention hook.
        target_device: Device quantization hooks expect parameters on.
    """
    for _, submodule in model.named_modules():
        if isinstance(submodule, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            submodule.process_weights_after_loading()
            continue

        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue

        # Quant methods that repack/quantize after loading expect the
        # parameters on the global target device. With CPU offloading the
        # parameters are moved onto the device for processing and moved
        # back afterwards.
        with device_loading_context(submodule, target_device):
            method.process_weights_after_loading(submodule)

    # Currently only used by MLA.
    # NOTE: This intentionally happens after other modules so we can easily
    # decompress the weights for MLA.
    for _, submodule in model.named_modules():
        if not isinstance(submodule, Attention):
            continue
        if hasattr(submodule, "process_weights_after_loading"):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            submodule.process_weights_after_loading(model_config.dtype)
class BaseModelLoader(ABC):
    """Abstract interface implemented by every model loader in this file."""

    def __init__(self, load_config: LoadConfig):
        # Subclasses read their options from this configuration object.
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Build the model described by ``vllm_config`` and return it."""
        raise NotImplementedError

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Fetch the model's files so a later load can start immediately."""
        raise NotImplementedError
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    # `time.perf_counter()` readings taken when the first weights iterator is
    # created and when `load_model` finishes loading; 0.0 means "not set yet".
    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``; this loader accepts no extra config.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(f"Model loader extra config is not supported for "
                         f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if not VLLM_USE_MODELSCOPE:
            return None

        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            # An existing local path is used as-is.
            return model

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model, self.load_config.download_dir):
            model_path = snapshot_download(
                model_id=model,
                cache_dir=self.load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                revision=revision,
                ignore_file_pattern=self.load_config.ignore_patterns,
            )
        return model_path

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns:
            ``(folder, weight_files, use_safetensors)``.

        Raises:
            ValueError: For a load format this loader does not know.
            RuntimeError: When no weight file is left after filtering.
        """
        modelscope_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision)
        model_name_or_path = modelscope_path or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        index_file = SAFE_WEIGHTS_INDEX_NAME
        use_safetensors = False

        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif (load_format == LoadFormat.SAFETENSORS
              or load_format == LoadFormat.FASTSAFETENSORS):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns.append("*.pt")

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if is_local:
            hf_folder = model_name_or_path
        else:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        # Patterns are tried in order; the first one that matches any file
        # decides which files are used.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files.extend(glob.glob(os.path.join(hf_folder,
                                                           pattern)))
            if not hf_weights_files:
                continue
            if pattern == "*.safetensors":
                use_safetensors = True
            break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        show_progress = self.load_config.use_tqdm_on_load

        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                show_progress,
            )
        elif not use_safetensors:
            weights_iterator = pt_weights_iterator(
                hf_weights_files,
                show_progress,
            )
        elif self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
            weights_iterator = fastsafetensors_weights_iterator(
                hf_weights_files,
                show_progress,
            )
        else:
            weights_iterator = safetensors_weights_iterator(
                hf_weights_files,
                show_progress,
            )

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)
        elif current_platform.is_hpu():
            import habana_frameworks.torch.core as htcore

            def _hpu_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    htcore.mark_step()

            weights_iterator = _hpu_weights_iterator(weights_iterator)

        # Only the first iterator created by this loader starts the clock.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield the primary checkpoint's weights, then every secondary one.

        Secondary sources come from the model's optional
        ``secondary_weights`` attribute.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for extra_source in secondary_weights:
            yield from self._get_weights_iterator(extra_source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (downloading if needed) the weight files of the model."""
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model, load its checkpoint and return it in eval mode.

        Raises:
            ValueError: For a non-quantized model whose ``load_weights``
                reports loaded names, if some parameter was not loaded.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        target_device = torch.device(device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self.get_all_weights(model_config, model))

            self.counter_after_loading_weights = time.perf_counter()
            logger.info(
                "Loading weights took %.2f seconds",
                self.counter_after_loading_weights -
                self.counter_before_loading_weights)

            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            strict_check = (model_config.quantization is None
                            and loaded_weights is not None)
            if strict_check:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")

            _process_weights_after_loading(model, model_config, target_device)

        return model.eval()
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``; this loader accepts no extra config.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(f"Model loader extra config is not supported for "
                         f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing: dummy weights need no download."""

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model, fill it with dummy weights, return it in eval mode."""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

            _process_weights_after_loading(model, model_config, target_device)

        return model.eval()
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        """Take the tensorizer settings from the loader's extra config.

        A ``TensorizerConfig`` instance is used as-is; anything else is
        expanded as keyword arguments into a new ``TensorizerConfig``.
        """
        super().__init__(load_config)
        extra = load_config.model_loader_extra_config
        if not isinstance(extra, TensorizerConfig):
            extra = TensorizerConfig(**extra)
        self.tensorizer_config = extra

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Check the tensorizer config against model and parallel configs."""
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the weights iterator built from the tensorizer arguments."""
        args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/other/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        model_config = vllm_config.model_config
        device = vllm_config.device_config.device

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device):
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())

        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/other/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        model_config = vllm_config.model_config
        device = vllm_config.device_config.device

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device):
                model_class = get_model_architecture(model_config)[0]

                # Work on a shallow copy so the loader's own config keeps
                # its original model_class / hf_config / dtype values.
                per_load_config = copy.copy(self.tensorizer_config)
                per_load_config.model_class = model_class
                per_load_config.hf_config = model_config.hf_config
                per_load_config.dtype = model_config.dtype

                model = load_with_tensorizer(per_load_config,
                                             vllm_config=vllm_config)

        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config against the model and open the stream once."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load the model through the vLLM-tensorized path when possible.

        With tensor parallelism enabled, ``tensorizer_uri`` is treated as a
        ``%``-format template and filled in with this worker's
        tensor-parallel rank; the formatted value is stored back on the
        loader's config.
        """
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            uri_template = self.tensorizer_config.tensorizer_uri
            self.tensorizer_config.tensorizer_uri = (
                uri_template % get_tensor_model_parallel_rank())

        if not is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized_cpu(vllm_config=vllm_config)
        return self._load_model_serialized(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` with the given tensorizer configuration."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self,
                 load_config: LoadConfig,
                 runai_model_streamer: bool = False):
        """Read the optional ``pattern`` key from the loader's extra config.

        Args:
            load_config: Load configuration; its ``model_loader_extra_config``
                may carry a ``pattern`` overriding ``DEFAULT_PATTERN``.
            runai_model_streamer: When true, shard files are read through
                ``runai_safetensors_weights_iterator``.

        Raises:
            ValueError: If the extra config has keys other than ``pattern``.
        """
        super().__init__(load_config)

        self.runai_model_streamer = runai_model_streamer
        # Copy before popping so the caller's config dict is not mutated.
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{load_config.model_loader_extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
        tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group non-empty tensors by (device, storage base pointer): only
        # tensors in the same group can alias each other.
        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address one element past the tensor's last element.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        # t2 does not cover t's whole range.
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group supersedes t.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a location holding the sharded checkpoint.

        S3 URIs and existing local directories are returned unchanged;
        anything else is downloaded (``*.safetensors`` only) and the result
        of ``download_weights_from_hf`` is returned.
        """
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path

        allow_patterns = ["*.safetensors"]
        return download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Make the sharded checkpoint available locally."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and fill it from this rank's shard files.

        Raises:
            ValueError: If no shard file matches this rank's pattern, or if
                some state-dict entries are absent from the shard files.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        target_device = torch.device(device_config.device)

        from vllm.distributed import get_tensor_model_parallel_rank

        # Prefer a dedicated `model_weights` location when the config has one.
        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        local_model_path = model_weights

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
                _process_weights_after_loading(model, model_config,
                                               target_device)

            rank = get_tensor_model_parallel_rank()
            shard_glob = self.pattern.format(rank=rank, part="*")
            pattern = os.path.join(local_model_path, shard_glob)

            if is_s3(local_model_path):
                # The wildcard must be a bare "*": padding it with spaces
                # would require literal spaces in the object names and so
                # match none of the files written by `save_model`.
                filepaths = s3_glob(path=local_model_path,
                                    allow_pattern=[f"*{shard_glob}"])
            else:
                filepaths = glob.glob(pattern)

            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")

            state_dict = self._filter_subtensors(model.state_dict())
            for key, tensor in self.iterate_over_files(filepaths):
                # If loading with LoRA enabled, additional padding may
                # be added to certain parameters. We only load into a
                # narrowed view of the parameter data.
                param_data = state_dict[key].data
                param_shape = state_dict[key].shape
                for dim, size in enumerate(tensor.shape):
                    if size < param_shape[dim]:
                        param_data = param_data.narrow(dim, 0, size)
                if tensor.shape != param_shape:
                    logger.warning(
                        "loading tensor of shape %s into "
                        "parameter '%s' of shape %s",
                        tensor.shape,
                        key,
                        param_shape,
                    )
                param_data.copy_(tensor)
                # Whatever is left at the end was missing from the files.
                state_dict.pop(key)

            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")

        return model.eval()

    def iterate_over_files(
            self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(key, tensor)`` for every tensor stored in ``paths``."""
        if self.runai_model_streamer:
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open

            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write this rank's state dict to ``path`` as safetensors parts.

        Args:
            model: Model whose (alias-filtered) state dict is saved.
            path: Directory the part files are written into.
            pattern: File-name template with ``{rank}`` and ``{part}``
                fields; defaults to ``DEFAULT_PATTERN``.
            max_size: Byte budget per part. A new part is started before a
                tensor that would push the current part over the budget; a
                single tensor larger than the budget gets a part of its own.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Only flush a non-empty part: otherwise a first tensor larger
            # than `max_size` would produce an empty part file.
            if (max_size is not None and state_dict_part
                    and total_size + param_size > max_size):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if state_dict_part:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
    """Store ``load_config`` and start with empty module bookkeeping."""
    super().__init__(load_config)

    # Maps weight names from transformers to vllm; identity by default.
    self.weight_mapper: Callable = lambda name: name
    # All module names (from transformers) that support BNB quantization.
    self.target_modules: List[str] = []
    # Module names whose weights are sharded by column.
    self.column_sharded_weights_modules: List[str] = []
    # Module names whose weights are not sharded at all.
    self.unsharded_weights_modules: List[str] = []
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None,
) -> Tuple[str, List[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(
os.path.join(model_name_or_path, pattern))
if weight_files:
return model_name_or_path, weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return hf_folder, glob.glob(
os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(self, model_name_or_path: str,
                     revision: Optional[str]) -> Tuple[List[str], bool]:
    """Prepare weight files for the model.

    Returns:
        ``(weight_files, use_safetensors)``.

    Raises:
        RuntimeError: If no weight file is left after filtering.
    """
    candidate_patterns = ["*.safetensors", "*.bin", "*.pt"]
    hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
        model_name_or_path, candidate_patterns, revision)

    use_safetensors = matched_pattern == "*.safetensors"
    is_local = os.path.isdir(model_name_or_path)
    index_file = SAFE_WEIGHTS_INDEX_NAME

    if not use_safetensors:
        hf_weights_files = filter_files_not_needed_for_inference(
            hf_weights_files)
    else:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                self.load_config.download_dir,
                revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file)

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Yield ``(original_name, mapped_name, tensor)`` for every weight.

    ``mapped_name`` is the original name passed through
    ``self.weight_mapper`` (transformers naming to vllm naming); the
    original name is kept alongside it.
    """
    read_weights = (safetensors_weights_iterator
                    if use_safetensors else pt_weights_iterator)
    weights = read_weights(
        hf_weights_files,
        self.load_config.use_tqdm_on_load,
    )
    for original_name, tensor in weights:
        yield original_name, self.weight_mapper(original_name), tensor
def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: Optional[str],
pre_quant: bool,
load_8bit: bool,
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.45.3":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.3.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision)
quant_state_dict: Dict[str, Any] = {}
if pre_quant:
if load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
else:
return self._quantized_4bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix)
for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"bitsandbytes",
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue
if mapped_weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
from bitsandbytes.functional import QuantState
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in mapped_weight_name:
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[mapped_weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
temp_state_dict: Dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(quant_state,
device=current_platform.device_type)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
in temp_state_dict) or (
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
in temp_state_dict):
quant_state = _parse_quant_state(mapped_weight_name,
temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _unquantized_generator(self, hf_weights_files, use_safetensors,
                           quant_state_dict) -> Generator:
    """Yield ``(original_name, tensor)``, quantizing target weights to nf4.

    A weight is quantized when its mapped name contains one of
    ``self.target_modules`` and ends with ``.weight``. Such a weight is
    first reduced to this tensor-parallel rank's slice, moved to CUDA if
    needed, quantized with ``quantize_4bit``, and its quant state is stored
    in ``quant_state_dict`` under the mapped name. Every other tensor is
    yielded unchanged.
    """
    from bitsandbytes.functional import quantize_4bit

    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    for original_name, mapped_name, tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors):
        is_target = any(target in mapped_name
                        for target in self.target_modules
                        ) and mapped_name.endswith(".weight")
        if not is_target:
            yield original_name, tensor
            continue

        if any(
                mapped_name.startswith(module)
                for module in self.unsharded_weights_modules):
            # Without sharding
            shard = tensor
        elif any(
                mapped_name.startswith(module)
                for module in self.column_sharded_weights_modules):
            # Shard by column
            width = tensor.size(-1)
            lo = width // tp_size * tp_rank
            hi = width // tp_size * (tp_rank + 1)
            shard = tensor[..., lo:hi]
        elif any(
                mapped_name.startswith(module)
                for module in self.maybe_fused_weights_modules):
            # Weights have fused on disk. In this case, we assume that the
            # weight and module use same name.
            # Sizes of the fused sub-weights along dim 0.
            fused_sizes = next(
                (sizes for module, sizes in
                 self.maybe_fused_weights_modules.items()
                 if mapped_name.startswith(module)))
            assert tensor.size(0) == sum(fused_sizes)
            # Row offset at which each fused sub-weight starts.
            fused_offsets = list(
                itertools.accumulate([0] + fused_sizes))[:-1]
            # Take this rank's rows from every sub-weight, then stitch
            # them back together in the original order.
            pieces = []
            for offset, size in zip(fused_offsets, fused_sizes):
                lo = offset + size // tp_size * tp_rank
                hi = offset + size // tp_size * (tp_rank + 1)
                pieces.append(tensor[lo:hi, ...])
            shard = torch.cat(pieces, dim=0)
        else:
            # Shard by row
            height = tensor.size(0)
            lo = height // tp_size * tp_rank
            hi = height // tp_size * (tp_rank + 1)
            shard = tensor[lo:hi, ...]

        # bitsandbytes requires data in GPU
        on_gpu = shard if shard.is_cuda else shard.cuda()

        # remove the following after the issue is fixed:
        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
        if on_gpu.is_contiguous() is False:
            on_gpu = on_gpu.contiguous()

        with set_default_torch_dtype(torch.float32):
            quantized, quant_state = quantize_4bit(
                on_gpu,
                compress_statistics=True,
                quant_type="nf4",
            )

        quant_state_dict[mapped_name] = quant_state
        yield original_name, quantized
def _get_bnb_target_modules(self, model: nn.Module) -> None:
    """Collect the module names whose weights are BNB quantization targets.

    Walks ``model.named_modules()`` and, for every ``LinearBase`` instance,
    appends candidate names to ``self.target_modules``: the names obtained
    by substituting each sub-module reported by
    ``self.modules_mapping.get_sub_modules`` (when there are any), followed
    by the module's own name.

    Args:
        model: The model whose sub-modules are inspected.

    Raises:
        AssertionError: If ``self.target_modules`` is still empty after the
            walk. The message names the model class.
    """
    for name, module in model.named_modules():
        if isinstance(module, (LinearBase, )):
            if modules_info := self.modules_mapping.get_sub_modules(name):
                # Map vllm's names to transformers's names.
                rep_name, sub_modules = modules_info
                for sub_name in sub_modules:
                    self.target_modules.append(
                        name.replace(rep_name, sub_name))
            # Add original module name even if the module has stacked map,
            # in case model has a mixture of disk-merged and disk-splitted
            # weights with same last name.
            self.target_modules.append(name)

    # Both literals are wrapped in one parenthesised expression so that the
    # model name is part of the assertion message; as two separate
    # statements the f-string would be evaluated and thrown away.
    assert self.target_modules, (
        "vllm currently does not support BNB quantization for"
        f" {type(model).__name__}")
def _load_weights(self, model_config: ModelConfig,
                  model: nn.Module) -> None:
    """Load (and, if needed, quantize) the weights of ``model`` for BNB.

    The method runs in four stages:

    1. Validate that ``model`` exposes ``load_weights`` and
       ``packed_modules_mapping`` and record per-module sharding
       information on ``self``.
    2. Decide whether the checkpoint is already quantized by reading
       ``quantization_config`` from ``model_config.hf_config``.
    3. Feed the iterator returned by
       ``self._get_quantized_weights_iterator`` to ``model.load_weights``.
    4. Attach the collected quant states to the matching parameters as
       the ``bnb_quant_state`` and ``bnb_shard_offsets`` attributes (plus
       ``matmul_state`` when ``load_in_8bit`` is set).

    Args:
        model_config: Supplies ``hf_config``, ``model`` and ``revision``.
        model: The freshly initialised model to populate.

    Raises:
        AttributeError: If ``model`` lacks ``load_weights`` or
            ``packed_modules_mapping``.
        ValueError: If the checkpoint declares a quantization method other
            than ``"bitsandbytes"``, if a pre-quantized checkpoint is used
            with a tensor-parallel world size above 1, if the model reports
            parameters that were not loaded, or if a quantized parameter
            has no ``pack_factor`` attribute.
    """
    if not hasattr(model, "load_weights"):
        raise AttributeError(
            "The required method 'load_weights' is not defined in class"
            f" {type(model).__name__}.")

    if not hasattr(model, "packed_modules_mapping"):
        raise AttributeError(
            f"Model {type(model).__name__} does not support BitsAndBytes "
            "quantization yet. No 'packed_modules_mapping' found.")
    # Deep-copy so that later use of the mapping cannot mutate the model's
    # own attribute.
    self.modules_mapping = ParamMapping(
        copy.deepcopy(model.packed_modules_mapping))

    # For some models like Molmo, we need to use hf_to_vllm_mapper
    # to ensure correct loading of weights.
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

    # Modules whose weights might be fused on disk; their output_sizes are
    # needed to shard them correctly in flight under TP.
    self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
    self._get_bnb_target_modules(model)
    for name, module in model.named_modules():
        # Some modules like `ReplicatedLinear` should not have their weights
        # sharded. The reason for implementing it this way is to avoid new
        # static variable in the model implementation.
        if isinstance(module, (ReplicatedLinear, )):
            self.unsharded_weights_modules.append(name)
        # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
        # fused weights on disk. We need to use the output sizes of these
        # modules to shard the weights correctly.
        elif isinstance(module,
                        (QKVParallelLinear, MergedColumnParallelLinear)):
            self.maybe_fused_weights_modules[name] = module.output_sizes
        # In TP, these weights are partitioned along the column
        # dimension (dim=-1)
        elif isinstance(module, (RowParallelLinear, )):
            self.column_sharded_weights_modules.append(name)

    self.model_type = type(model).__name__

    logger.info("Loading weights with BitsAndBytes quantization. "
                "May take a while ...")

    quant_config = getattr(model_config.hf_config, "quantization_config",
                           None)

    # A checkpoint counts as pre-quantized only when its config names
    # "bitsandbytes"; any other declared method is rejected.
    pre_quant = False
    if quant_config is not None:
        quant_method = quant_config.get("quant_method")
        if quant_method == "bitsandbytes":
            pre_quant = True
        else:
            raise ValueError(
                f"BitsAndBytes loader does not support {quant_method} "
                "quantization")

    # The quant_states in pre_quantized models cannot work with a split
    # weight tensor. So TP does not work with pre_quantized bnb models.
    if pre_quant and get_tensor_model_parallel_world_size() > 1:
        raise ValueError(
            "Prequant BitsAndBytes models with tensor parallelism is not "
            "supported. Please try with pipeline parallelism.")

    load_8bit = False
    if pre_quant:
        load_8bit = quant_config.get("load_in_8bit", False)

    qweight_iterator, quant_state_dict = (
        self._get_quantized_weights_iterator(model_config.model,
                                             model_config.revision,
                                             pre_quant, load_8bit))

    # Snapshot the parameter names first so they can be compared with what
    # load_weights reports as loaded.
    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights(qweight_iterator)
    # Some models may have weights loading tracker unimplemented.
    if loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError("Following weights were not initialized from "
                             f"checkpoint: {weights_not_loaded}")

    torch.cuda.empty_cache()

    param_dict = dict(model.named_parameters())
    # Maps a (possibly packed) parameter name to {shard index: quant state}.
    stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
    # TODO: Change this lazy import to normal import
    # after the checks are updated to run on a new version
    from vllm.model_executor.models.utils import is_pp_missing_parameter
    for quant_param_name in quant_state_dict:
        if is_pp_missing_parameter(quant_param_name, model):
            continue

        # Remember the on-disk name: it is the key into quant_state_dict
        # even if quant_param_name gets rewritten below.
        non_stacked_param_name = quant_param_name

        shard_index = 0
        for shard_name, (
                weight_name,
                index,
        ) in self.modules_mapping.inverse_packed_mapping.items():
            # Some models, such as MiniCPM V2.5/2.6, contain both
            # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
            # from being incorrectly identified as being present in
            # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight', the match
            # must start right after a '.' separator.
            shard_pos = quant_param_name.find(shard_name)
            can_correct_rename = (shard_pos
                                  > 0) and (quant_param_name[shard_pos - 1]
                                            == ".")
            # If the quant_param_name is packed, it won't occur in the
            # param_dict before renaming.
            new_quant_param_name = quant_param_name.replace(
                shard_name, weight_name)
            need_rename = (quant_param_name not in param_dict) \
                          and (new_quant_param_name in param_dict)
            if can_correct_rename and need_rename:
                shard_index = index
                quant_param_name = new_quant_param_name
                break

        # Models like Clip/Siglip may skip some layers in initialization,
        # causing unused quant_param_name in state_dict.
        if quant_param_name not in param_dict:
            continue

        if quant_param_name not in stacked_quant_state_dict:
            stacked_quant_state_dict[quant_param_name] = {}

        stacked_quant_state_dict[quant_param_name][shard_index] = (
            quant_state_dict[non_stacked_param_name])

    # save quant_states and offsets as the attributes of the parameters
    for param_name, param in param_dict.items():
        if param_name in stacked_quant_state_dict:
            quant_states = stacked_quant_state_dict[param_name]
            set_weight_attrs(param, {"bnb_quant_state": quant_states})

            pack_ratio = getattr(param, "pack_factor", -1)
            if pack_ratio == -1:
                raise ValueError(
                    f"pack_factor not set for parameter {param_name}.")

            # Packed element count of every shard, indexed by shard index.
            num_elements = [0] * len(quant_states)
            for seq, quant_state in quant_states.items():
                num_elements[seq] = (math.prod(quant_state.shape) //
                                     pack_ratio)

            # Prefix sums with a leading 0: offsets[i] is where shard i
            # starts and offsets[-1] is the total element count.
            offsets = np.concatenate(([0], np.cumsum(num_elements)))
            # Make torch infer_schema happy
            offsets = torch.tensor(offsets).cpu()
            set_weight_attrs(param, {"bnb_shard_offsets": offsets})

            if load_8bit:
                # One placeholder slot per shard.
                set_weight_attrs(
                    param, {"matmul_state": [None] * len(quant_states)})
def download_model(self, model_config: ModelConfig) -> None:
    """Prepare the weights of ``model_config.model`` without loading them.

    Delegates to ``self._prepare_weights`` with the configured model and
    revision and discards the result.
    """
    self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
    """Build the model described by ``vllm_config`` and load its weights.

    The model is initialised and its weights are loaded while the default
    torch dtype is ``model_config.dtype`` and the configured device is the
    active torch device.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    device_cfg = vllm_config.device_config
    model_cfg = vllm_config.model_config
    # Entered left to right, exactly like two nested ``with`` statements.
    with set_default_torch_dtype(model_cfg.dtype), \
            torch.device(device_cfg.device):
        loaded = _initialize_model(vllm_config=vllm_config)
        self._load_weights(model_cfg, loaded)
    return loaded.eval()
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader for checkpoints quantized with GGUF and stored in the GGUF
    file format. The checkpoint must be given as a path to a local file.
    """

    def __init__(self, load_config: LoadConfig):
        """Reject any ``model_loader_extra_config``; GGUF accepts none."""
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return ``model_name_or_path`` if it is an existing file.

        Raises:
            ValueError: If the path does not point to a file.
        """
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build the mapping from GGUF tensor names to HF parameter names.

        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Raises:
            RuntimeError: If the model type has no entry in
                ``gguf.MODEL_ARCH_NAMES``.
        """
        config = model_config.hf_config
        model_type = config.model_type
        gguf_to_hf_name_map = {}

        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        if model_type in ("deepseek_v3", "deepseek_v2"):
            model_type = "deepseek2"
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            expert_tensor_names = (
                ("exp_probs_b.bias", "gate.e_score_correction_bias"),
                ("ffn_down_exps.weight", "experts.0.down_proj.weight"),
                ("ffn_gate_exps.weight", "experts.0.gate_proj.weight"),
                ("ffn_up_exps.weight", "experts.0.up_proj.weight"),
            )
            for idx in range(config.num_hidden_layers):
                for gguf_suffix, hf_suffix in expert_tensor_names:
                    gguf_to_hf_name_map[f"blk.{idx}.{gguf_suffix}"] = \
                        f"model.layers.{idx}.mlp.{hf_suffix}"

        # First architecture key whose registered name equals the model type.
        arch = next((key for key, value in gguf.MODEL_ARCH_NAMES.items()
                     if value == model_type), None)
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        name_map = gguf.get_tensor_name_map(arch, config.num_hidden_layers)
        # Instantiate on the meta device: only the parameter names are used.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=model_config.trust_remote_code)

        for hf_name in dummy_model.state_dict():
            # Split off the trailing component (e.g. "weight" / "bias").
            name, suffix = hf_name.rsplit(".", 1)
            gguf_name = name_map.get_name(name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the GGUF weights iterator for the given file and name map."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Check that ``model_config.model`` is an existing file."""
        self._prepare_weights(model_config.model)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialise the model on the target device and load GGUF weights."""
        model_config = vllm_config.model_config
        device_config = vllm_config.device_config
        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        extra_tensor_names = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in extra_tensor_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            weights = self._get_weights_iterator(gguf_path, name_map)
            model.load_weights(weights)

            _process_weights_after_loading(model, model_config, target_device)
        return model
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that streams safetensors files, either from the local
    file system or from an S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        """Export streamer settings from the extra config to the environment.

        Nothing is exported when ``model_loader_extra_config`` is empty.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return

        # Integer-valued options are forwarded through environment variables.
        env_by_option = (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        )
        for option, env_name in env_by_option:
            if (option in extra_config
                    and isinstance(extra_config.get(option), int)):
                os.environ[env_name] = str(extra_config.get(option))

        # Fall back to the AWS endpoint when no streamer endpoint is set.
        runai_streamer_s3_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT')
        aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
        if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> List[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Raises:
            RuntimeError: If no ``*.safetensors`` file is found.
        """
        is_s3_path = is_s3(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        needs_download = not (is_local or is_s3_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        if needs_download:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [safetensors_pattern],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        if is_s3_path:
            hf_weights_files = s3_glob(path=hf_folder,
                                       allow_pattern=[safetensors_pattern])
        else:
            hf_weights_files = glob.glob(
                os.path.join(hf_folder, safetensors_pattern))

        if needs_download:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir,
                revision)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")

        return hf_weights_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        weight_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            weight_files,
            self.load_config.use_tqdm_on_load,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Perform streaming of the model to destination"""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            # A dedicated ``model_weights`` location, when the config has
            # one, takes precedence over ``model``.
            if hasattr(model_config, "model_weights"):
                weights_location = model_config.model_weights
            else:
                weights_location = model_config.model
            model.load_weights(
                self._get_weights_iterator(weights_location,
                                           model_config.revision))

            _process_weights_after_loading(model, model_config, target_device)
        return model.eval()
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format.

    ``load_config.load_format`` may be a class, in which case it is called
    with ``load_config``. Otherwise it is compared against the ``LoadFormat``
    members below, and ``DefaultModelLoader`` is used when none matches.
    """
    load_format = load_config.load_format

    if isinstance(load_format, type):
        return load_format(load_config)

    if load_format == LoadFormat.DUMMY:
        loader = DummyModelLoader(load_config)
    elif load_format == LoadFormat.TENSORIZER:
        loader = TensorizerLoader(load_config)
    elif load_format == LoadFormat.SHARDED_STATE:
        loader = ShardedStateLoader(load_config)
    elif load_format == LoadFormat.BITSANDBYTES:
        loader = BitsAndBytesModelLoader(load_config)
    elif load_format == LoadFormat.GGUF:
        loader = GGUFModelLoader(load_config)
    elif load_format == LoadFormat.RUNAI_STREAMER:
        loader = RunaiModelStreamerLoader(load_config)
    elif load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
        # Sharded-state loading with the streamer enabled.
        loader = ShardedStateLoader(load_config, runai_model_streamer=True)
    else:
        loader = DefaultModelLoader(load_config)
    return loader
......@@ -116,6 +116,7 @@ class Ernie4_5_MoeMoE(nn.Module):
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.moe_num_experts}.")
self.gate = ReplicatedLinear(config.hidden_size,
config.moe_num_experts,
bias=False,
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from argparse import Namespace
from typing import Optional
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class PatchEmbedding(nn.Module):
    """Turn an image batch into patch tokens with a leading class token.

    A strided convolution embeds non-overlapping patches, a learned class
    embedding is prepended, and the position-embedding table is added.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        patch_size = config.patch_size
        # Kernel size equals stride, so patches do not overlap.
        self.proj = nn.Conv2d(config.in_channels,
                              hidden_size,
                              kernel_size=patch_size,
                              stride=patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions,
                                               hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        conv_weight = self.proj.weight
        # Bring the input onto the projection's device and dtype.
        pixels = images.to(device=conv_weight.device,
                           dtype=conv_weight.dtype)
        # Flatten the spatial dims and move them in front of the channels.
        patches = self.proj(pixels).flatten(2).transpose(1, 2)
        cls_token = self.cls_embedding.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls_token, patches), dim=1)
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens
class Attention(nn.Module):
    """Self-attention built from a fused QKV projection and a dense output.

    The number of heads handled here is ``config.num_heads`` divided by the
    tensor-parallel world size.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        hidden_size = config.hidden_size
        total_heads = config.num_heads

        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = total_heads // self.tp_size
        self.head_dim = hidden_size // total_heads
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            total_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            hidden_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
                                       self.scale)
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention to ``x`` and return the projected, dropped-out result."""
        fused_qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
        # The fused projection is split into three equal parts on the last dim.
        query, key, value = fused_qkv.chunk(3, dim=-1)

        attended = self.attn(query, key, value)
        projected, _ = self.dense(attended)
        return self.output_dropout(projected)
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1, activation, fc2."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through fc1, the activation and fc2."""
        # Both linear layers return a pair; only the first item is used.
        expanded, _ = self.fc1(x)
        activated = self.activation_fn(expanded)
        contracted, _ = self.fc2(activated)
        return contracted
class TransformerLayer(nn.Module):
    """One encoder layer: attention and MLP, each followed by a LayerNorm.

    The normalisation is applied to the output of each sub-block before it
    is added back to that sub-block's input.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.input_layernorm = LayerNorm(hidden_size, eps=norm_eps)
        self.attention = Attention(config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.attention")
        self.mlp = MLP(config,
                       quant_config=quant_config,
                       prefix=f"{prefix}.mlp")
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=norm_eps)

    def forward(self, hidden_states):
        """Return ``hidden_states`` after the attention and MLP residuals."""
        # Attention sub-block: normalise its output, then add the residual.
        normed_attention = self.input_layernorm(self.attention(hidden_states))
        after_attention = hidden_states + normed_attention

        # MLP sub-block: same pattern on the attention result.
        normed_mlp = self.post_attention_layernorm(self.mlp(after_attention))
        return after_attention + normed_mlp
class Transformer(nn.Module):
    """A stack of ``config.num_hidden_layers`` ``TransformerLayer`` modules."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            # Each layer gets its index in the prefix.
            self.layers.append(
                TransformerLayer(config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.layers.{layer_idx}"))

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order."""
        output = hidden_states
        for layer in self.layers:
            output = layer(output)
        return output
class GLU(nn.Module):
    """Gated projection head: linear, LayerNorm + GELU, gated MLP.

    The gate and up projections are held in a single
    ``MergedColumnParallelLinear`` (``merged_proj``) whose output is consumed
    by ``SiluAndMul``.
    """

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        """
        The original implementation used two separate projections,
        ``gate_proj`` and ``dense_h_to_4h``, each a ``ColumnParallelLinear``
        from ``config.hidden_size`` to ``config.ffn_hidden_size`` without
        bias, and concatenated their outputs on the last dimension:
        ```
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        Here both are merged into one ``MergedColumnParallelLinear`` with
        output sizes ``[config.ffn_hidden_size] * 2``, so the same result is
        obtained with a single call:
        ```
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        hidden_size = config.hidden_size
        ffn_hidden_size = config.ffn_hidden_size

        self.linear_proj = ReplicatedLinear(in_features,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.linear_proj")
        self.norm1 = nn.LayerNorm(hidden_size)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            hidden_size, [ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj")

        self.dense_4h_to_h = RowParallelLinear(
            ffn_hidden_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h")

    def forward(self, x):
        """Project ``x``, normalise and activate it, then apply the gated MLP."""
        projected, _ = self.linear_proj(x)
        normalised = self.act1(self.norm1(projected))
        gate_and_up, _ = self.merged_proj(normalised)
        gated = self.act2(gate_and_up)
        output, _ = self.dense_4h_to_h(gated)
        return output
class EVA2CLIPModel(nn.Module):
    """Visual encoder: patch embedding, transformer, conv downsample, GLU.

    The output sequence is wrapped in learned begin-of-image and
    end-of-image embeddings and divided by ``vision_config.scaling_factor``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = '',
    ):
        super().__init__()
        # ``config.vision_config`` is a mapping; expose its keys as attributes.
        vision_config = Namespace(**config.vision_config)
        text_hidden_size = config.hidden_size

        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.transformer")
        self.linear_proj = GLU(config,
                               in_features=text_hidden_size,
                               quant_config=quant_config,
                               prefix=f"{prefix}.linear_proj")
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=text_hidden_size,
                              kernel_size=2,
                              stride=2)
        self.boi = nn.Parameter(torch.zeros(1, 1, text_hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, text_hidden_size))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        tokens = self.transformer(self.patch_embedding(images))
        # Drop the first position, which PatchEmbedding fills with the
        # class token.
        patch_tokens = tokens[:, 1:]

        batch, seq_len, hidden = patch_tokens.shape
        # The remaining tokens are laid out on a square grid.
        side = int(seq_len**0.5)
        grid = patch_tokens.view(batch, side, side, hidden).permute(0, 3, 1, 2)

        # 2x2 convolution with stride 2, then back to a token sequence.
        pooled = self.conv(grid).flatten(2).transpose(1, 2)
        features = self.linear_proj(pooled)

        n_images = features.shape[0]
        boi = self.boi.expand(n_images, -1, -1)
        eoi = self.eoi.expand(n_images, -1, -1)
        wrapped = torch.cat((boi, features, eoi), dim=1)
        return wrapped / self.scaling_factor
\ No newline at end of file
......@@ -25,6 +25,7 @@ import torch
from torch import nn
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
......
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm import _custom_ops as ops
from vllm.distributed import tensor_model_parallel_all_gather, tensor_model_parallel_gather
from vllm import envs
SQRT2 = 2**0.5
......@@ -215,7 +216,7 @@ class MLPSpeculator(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn:
if self.use_llama_nn or envs.VLLM_USE_NN:
if (os.environ['LM_NN'] == '1' and "head" in name) or "proj" in name:
_weight = torch.zeros_like(param.data)
ori_shape =_weight.shape
......
......@@ -436,7 +436,7 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
x=x.to(memory_format=torch.channels_last_3d)
# x=x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size)
return x
......
......@@ -89,7 +89,6 @@ _TEXT_GENERATION_MODELS = {
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
......@@ -132,8 +131,6 @@ _TEXT_GENERATION_MODELS = {
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
"Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"),
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
......
......@@ -96,13 +96,11 @@ class CpuPlatform(Platform):
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
logger.info("Using CPU MLA backend.")
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
raise NotImplementedError("MLA is not supported on CPU.")
logger.info("Using Torch SDPA backend.")
if use_v1:
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
else:
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
if not use_v1:
raise ValueError("CPU backend only supports V1.")
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
......@@ -185,26 +183,14 @@ class CpuPlatform(Platform):
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "mp"
if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
# Note: workaround for v1 gpu_model_runner
from vllm.config import CompilationLevel
vllm_config.compilation_config.cudagraph_capture_sizes = []
compilation_config = vllm_config.compilation_config
if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE):
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
# Note: vLLM V1 is using PIECEWISE level compilation, which will
# take time to compile kernels just-in-time with the inductor
......
......@@ -75,12 +75,12 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
}
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
# if "HIP_VISIBLE_DEVICES" in os.environ:
# val = os.environ["HIP_VISIBLE_DEVICES"]
# if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
# assert val == cuda_val
# else:
# os.environ["CUDA_VISIBLE_DEVICES"] = val
if "HIP_VISIBLE_DEVICES" in os.environ:
val = os.environ["HIP_VISIBLE_DEVICES"]
if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
assert val == cuda_val
else:
os.environ["CUDA_VISIBLE_DEVICES"] = val
# AMDSMI utils
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
......
......@@ -53,12 +53,10 @@ class TpuPlatform(Platform):
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)
if use_v1:
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
else:
logger.info("Using Pallas backend.")
return "vllm.attention.backends.pallas.PallasAttentionBackend"
if not use_v1:
raise ValueError("TPU backend only supports V1.")
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
@classmethod
def set_device(cls, device: torch.device) -> None:
......@@ -78,7 +76,7 @@ class TpuPlatform(Platform):
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return not envs.VLLM_USE_V1
return False
@classmethod
def get_punica_wrapper(cls) -> str:
......@@ -129,31 +127,19 @@ class TpuPlatform(Platform):
"Using bfloat16 instead.", model_config.dtype)
model_config.dtype = torch.bfloat16
if envs.VLLM_USE_V1:
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Multi-step scheduling is not supported (and not "
"needed) on vLLM V1. Please launch without "
"--num-scheduler-steps.")
else:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.tpu_worker.TPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.tpu_worker.TPUWorker"
raise NotImplementedError(
"Multi-step scheduling is not supported (and not "
"needed) on vLLM V1. Please launch without "
"--num-scheduler-steps.")
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
......@@ -201,13 +187,9 @@ class TpuPlatform(Platform):
processed_inputs: ProcessorInputs,
) -> None:
"""Raises if this request is unsupported on this platform"""
if isinstance(params, SamplingParams):
if params.guided_decoding is not None and not envs.VLLM_USE_V1:
raise ValueError("Structured output is not supported on "
f"{cls.device_name} V0.")
if params.sampling_type == SamplingType.RANDOM_SEED:
raise ValueError(
"Torch XLA does not support per-request seed.")
if (isinstance(params, SamplingParams)
and params.sampling_type == SamplingType.RANDOM_SEED):
raise ValueError("Torch XLA does not support per-request seed.")
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
......
......@@ -40,12 +40,10 @@ class XPUPlatform(Platform):
if selected_backend is not None and selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
use_v1 = envs.VLLM_USE_V1
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
else:
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
if not use_v1:
raise ValueError("XPU backend only supports V1.")
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod
def set_device(cls, device: torch.device) -> None:
......@@ -90,10 +88,7 @@ class XPUPlatform(Platform):
model_config = vllm_config.model_config
# in V1(or with ipex chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None:
if envs.VLLM_USE_V1:
cache_config.block_size = 64
else:
cache_config.block_size = 16
cache_config.block_size = 64
# FIXME: Temporarily forcing eager mode
# remove after t.compile support stabilizes.
......@@ -118,11 +113,7 @@ class XPUPlatform(Platform):
# check and update parallel config
parallel_config = vllm_config.parallel_config
if envs.VLLM_USE_V1:
parallel_config.worker_cls =\
"vllm.v1.worker.xpu_worker.XPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
if parallel_config.distributed_executor_backend is None:
if parallel_config.world_size > 1:
......
......@@ -18,6 +18,7 @@ logger = init_logger(__name__)
class Glm4MoeModelReasoningParser(ReasoningParser):
"""
Reasoning parser for the Glm4MoeModel model.
The Glm4MoeModel model uses <think>...</think> tokens to denote reasoning
text within its output. The model provides a strict switch to disable
reasoning output via the 'enable_thinking=False' parameter. This parser
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
# Identifier of a sequence in the original (un-expanded) batch.
SeqId = int
# Identifier given to a sequence created for the expanded (scoring) batch.
TargetSeqId = int
# A single vocabulary token id.
TokenId = int
# Sampling parameters built with all default values.
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Filter the list to ignore invalid proposals. The loop variable is
        # deliberately not called `proposals` so that it cannot be confused
        # with the method argument of the same name.
        proposal_token_ids_list_without_skips = [
            token_ids for token_ids in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in token_ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.

        Returns:
            A tuple ``(spec_indices, non_spec_indices,
            target_seq_group_metadata_list, num_scoring_tokens)`` where the
            target list holds the non-speculative sequences followed by the
            expanded speculative ones, and ``num_scoring_tokens`` is the
            number of expanded speculative sequences.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
            split_batch_by_proposal_len(
                seq_group_metadata_list, proposal_lens_list)

        spec_expanded_seqs = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        num_scoring_tokens = len(spec_expanded_seqs)
        # Batch speculative and non-speculative (e.g. chunked prefill) requests
        # but make sure order is prefill|decode due to backend requirement.
        target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
        Augment input `scores` with non-speculative requests outputs.
        This includes decode requests with speculation turned off, as well
        as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).

        `scores` is modified in place (first position of each non-speculative
        row) and the same object is returned.
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            # Index of the last entry of each request in the flat output.
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[
                    nospec_sampled_token_idxs].unsqueeze(1)
        return scores

    def _contract_batch(
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.
        """
        contracted_bs = len(contracted_seq_group_metadata_list)
        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        # NOTE: this rebinds `k`; the value taken from the proposal tensor
        # shape is the one used below, not the `k` argument.
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # equal to the total expanded batch size minus the number of samples for
        # non-speculative sequences, prefill chunks with no out tokens included
        non_spec_expanded_bs = len(non_spec_indices)
        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Output buffers for the whole (contracted) batch, pre-filled with
        # -1 token ids, zero probabilities and -inf log-probabilities.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
        else:
            all_hidden_states = None

        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)
        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs
                and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned,
            # we can ignore non-terminal chunks (no out token).
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)
        # Contract remaining nonspec entries based on non_spec_indices, if any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
            non_spec_outputs, has_prompt_log)

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        It assumes all sequences in the batch were previously expanded.
        """

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        contracted_bs, k = proposals.proposal_token_ids.shape

        # Reshape tensors to original batch size
        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k + 1)
        target_probs = target_sampler_output.sampled_token_probs.reshape(
            *target_token_ids.shape, self._vocab_size)
        target_logprobs = target_sampler_output.logprobs.reshape(
            target_probs.shape)
        target_hidden_states = target_sampler_output.hidden_states
        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        return SpeculativeScores(probs=target_probs,
                                 token_ids=target_token_ids,
                                 logprobs=target_logprobs,
                                 hidden_states=target_hidden_states,
                                 prompt_logprobs=None)

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.

        target_seq_ids_iter provides sequence ids for the expanded batch,
        fulfilling the requirement that no seq id in the expanded batch is equal
        to the seq id in the original batch.
        """

        if not seq_group_metadata_list:
            return []

        target_seq_group_metadata = list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

        return target_seq_group_metadata

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.
        """
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        sampling_params = input_seq_group_metadata.sampling_params
        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        # One target sequence per prefix to score; each consumes a fresh id
        # from the shared iterator.
        for token_ids in token_ids_to_score:
            target_seq_group_metadata_list.append(
                self._create_single_target_seq_group_metadata(
                    input_seq_group_metadata,
                    input_seq_id,
                    next(target_seq_ids_iter),
                    token_ids,
                    sampling_params=sampling_params,
                ))

        return target_seq_group_metadata_list

    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
            sampling_params: Sampling parameters stored on the new metadata.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.prompt_token_ids_array
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
        mrope_position_delta = seq_data.mrope_position_delta

        new_seq_data_dict = {
            target_seq_id:
            SequenceData(
                prompt_token_ids,
                _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                        new_output_token_ids),
            ),
        }
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluate one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        for data in new_seq_data_dict.values():
            data.update_num_computed_tokens(data.get_len() - 1)
            data.mrope_position_delta = mrope_position_delta

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )

    @staticmethod
    def _split_scoring_output(
        sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
        """Split the target model output into speculative and non-speculative
        output.

        Returns:
            ``(spec_tokens, spec_probs, spec_logprobs, spec_hidden_states,
            non_spec_tokens, non_spec_probs, non_spec_logprobs,
            non_spec_hidden_states)``; both hidden-state entries are None
            when the sampler output carries no hidden states.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are non-speculative, latter samples are from speculative
        # scoring (prefill|decode order).
        split_sizes = (sampler_output.sampled_token_ids.numel() -
                       num_scoring_tokens, num_scoring_tokens)
        (non_spec_probs,
         spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
        (non_spec_sampled_tokens, spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
        (non_spec_logprobs,
         spec_logprobs) = sampler_output.logprobs.split(split_sizes)

        if sampler_output.hidden_states is not None:
            (non_spec_hidden_states, spec_hidden_states
             ) = sampler_output.hidden_states.split(split_sizes)
        else:
            non_spec_hidden_states, spec_hidden_states = None, None

        return (spec_sampled_tokens, spec_probs, spec_logprobs,
                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
                non_spec_logprobs, non_spec_hidden_states)

    @staticmethod
    def _create_target_seq_id_iterator(
            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids.

        Raises:
            ValueError: If `seq_ids` is empty (raised by `max`).
        """
        return count(start=max(seq_ids) + 1)

    @staticmethod
    def _get_token_ids_to_score(
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        empty_token_ids: List[TokenId] = []

        token_ids_to_score = [empty_token_ids]
        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                                  for i in range(len(full_spec_token_ids)))
        return token_ids_to_score
class BatchExpansionTreeStyleScorer(BatchExpansionTop1Scorer):
    """Batch-expansion scorer that scores exactly k positions per sequence.

    Compared with the parent class, `_get_token_ids_to_score` does not emit
    the empty prefix, so each speculative sequence expands into k (not k + 1)
    target sequences and the contracted tensors have a second dimension of k.

    NOTE(review): `_contract_batch` and `_contract_batch_all_spec` return
    plain tuples, while the inherited `score_proposals` is annotated as
    returning `SpeculativeScores` and forwards their result unchanged.
    Confirm which contract the callers of this class rely on.
    """

    def _contract_batch(
        self, contracted_bs: int, target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals, num_scoring_tokens: int,
        non_spec_indices: List[int], spec_indices: List[int], k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to. For compatibility with
        the inherited `score_proposals`, which passes the sequence group
        metadata list in this position, a non-int value is replaced by its
        length.

        Returns:
            ``(all_tokens, all_probs, all_logprobs, all_hidden_states)``;
            the last entry is None when the sampler output has no hidden
            states.
        """
        if not isinstance(contracted_bs, int):
            # The parent's score_proposals passes the seq-group metadata list
            # positionally here; tensor sizes below need the batch size.
            contracted_bs = len(contracted_bs)

        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k] back to [batch_size, k].
        # NOTE: this rebinds `k`; the `k` argument itself is not used.
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # equal to the total expanded batch size minus the number of samples for
        # non-speculative sequences.
        non_spec_expanded_bs = len(non_spec_target_token_ids)
        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Output buffers for the whole (contracted) batch, pre-filled with
        # -1 token ids, zero probabilities and -inf log-probabilities.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k, target_hidden_states.shape[-1]))
        else:
            all_hidden_states = None

        # Non-speculative rows only populate their first position.
        if non_spec_indices:
            all_tokens[non_spec_indices, :1] = \
                non_spec_target_token_ids.unsqueeze(1)
            all_probs[non_spec_indices, :1, :] = \
                non_spec_target_probs.unsqueeze(1)
            all_logprobs[non_spec_indices, :1, :] = \
                non_spec_target_logprobs.unsqueeze(1)
            if all_hidden_states is not None:
                assert non_spec_target_hidden_states is not None
                all_hidden_states[non_spec_indices, :1, :] = \
                    non_spec_target_hidden_states.unsqueeze(1)

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        return all_tokens, all_probs, all_logprobs, all_hidden_states

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        It assumes all sequences in the batch were previously expanded.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)``; the last entry
            is None when the sampler output has no hidden states.
        """

        # Map distinct sequences used to score each token
        # of shape [batch_size * k] back to [batch_size, k].
        contracted_bs, k = proposals.proposal_token_ids.shape

        # Reshape tensors to original batch size
        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k)
        target_probs = target_sampler_output.sampled_token_probs.reshape(
            *target_token_ids.shape, self._vocab_size)
        target_logprobs = target_sampler_output.logprobs.reshape(
            target_probs.shape)
        target_hidden_states = target_sampler_output.hidden_states
        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        return (target_token_ids, target_probs, target_logprobs,
                target_hidden_states)

    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
            sampling_params: Sampling parameters stored on the new metadata.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.prompt_token_ids_array
        # first step need to ignore output token generated by prefill phase
        if seq_data.get_first_step_flag():
            new_output_token_ids = [
                *seq_data.get_output_token_ids()[:-1], *token_ids
            ]
        else:
            new_output_token_ids = [
                *seq_data.get_output_token_ids(), *token_ids
            ]

        new_seq_data_dict = {
            target_seq_id:
            SequenceData(
                prompt_token_ids,
                _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                        new_output_token_ids),
            ),
        }
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluate one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        for data in new_seq_data_dict.values():
            data.update_num_computed_tokens(data.get_len() - 1)

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )

    def _get_token_ids_to_score(
        self,
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given a list of proposal token ids, return the list of token id
        prefixes that should be scored.

        Returns k output lists, one non-empty prefix per proposal token. The
        empty prefix is not included.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k lists)
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        return [
            full_spec_token_ids[:i + 1]
            for i in range(len(full_spec_token_ids))
        ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Set, Union
import torch
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
from vllm.worker.worker_base import WorkerBase
@dataclass
class SpeculativeProposals:
    """Container for the proposal tokens produced by a proposer.

    It also records how many speculative tokens each sequence has, plus
    optional tensors used by tree-style generation.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    # A flag to mark that there's no available proposals
    no_proposals: bool = False

    # The cart_candidates used in tree-style generation
    cart_candidates: Optional[torch.Tensor] = None

    # The retrieve_indices used in tree-style generation
    # (exact layout is not visible here - TODO confirm against the proposer).
    retrieve_indices: Optional[torch.Tensor] = None

    # tree-style attention masks
    tree_attn_masks: Optional[torch.Tensor] = None

    # tree-style position ids
    tree_position_ids: Optional[torch.Tensor] = None

    def __repr__(self):
        # Only the shape of the probabilities is shown; the token ids and
        # lengths are rendered in full.
        shown = ", ".join((
            f"proposal_token_ids={self.proposal_token_ids}",
            f"proposal_probs={self.proposal_probs.shape}",
            f"proposal_lens={self.proposal_lens}",
        ))
        return f"SpeculativeProposals({shown})"
@dataclass
class SpeculativeScores:
    """Container for the scores the scoring model assigns to speculative
    tokens.
    """

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Log-probabilities of the speculative tokens according to the scoring
    # model. These values can be used to generate Logprob objects that are
    # returned to the user.
    logprobs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    # Optional last hidden states from the scoring model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional lm_head logits from the scoring model.
    logits: Optional[torch.Tensor] = None

    # Scoring model may also return logprobs for prompt tokens
    # for each request, when chunked prefill is enabled.
    prompt_logprobs: Optional[List[PromptLogprobs]] = None

    def __repr__(self):
        # Report tensor shapes only, never the tensor contents.
        probs_shape = self.probs.shape
        token_ids_shape = self.token_ids.shape
        return (f"SpeculativeScores(probs={probs_shape}, "
                f"token_ids={token_ids_shape})")
class SpeculativeProposer(ABC):
    """Abstract interface for components that produce speculative
    proposals."""

    @abstractmethod
    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # If set, this contains all sequence IDs that were assigned
        # bonus tokens in their last forward pass.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Return speculative proposals for the given execution request.

        Concrete subclasses must override this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class SpeculativeScorer(ABC):
    """Abstract interface for components that score speculative proposals
    with a scoring worker."""

    def __init__(self, scorer_worker: WorkerBase,
                 device: Union[torch.device, str], vocab_size: int):
        """Store the scoring worker, device and vocabulary size.

        Args:
            scorer_worker: Worker used to run the scoring model.
            device: A `torch.device` or a string; a `torch.device` is
                reduced to its `type` attribute before being stored.
            vocab_size: Vocabulary size kept for use by subclasses.
        """
        self._scorer_worker = scorer_worker
        self._device = (device.type
                        if isinstance(device, torch.device) else device)
        self._vocab_size = vocab_size

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score `proposals` for the given execution request.

        Concrete subclasses must override this method.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment