Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
"""This file is used for /tests and /benchmarks"""
import numpy
import torch
# Quantization bit-widths accepted by get_pack_factor() and quantize_weights().
SUPPORTED_NUM_BITS = [4, 8]
# Group sizes accepted by quantize_weights(). The value -1 is replaced there
# by size_k, i.e. one group spanning the whole K dimension.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into one 32-bit word.

    Args:
        num_bits: Width in bits of a single quantized value.

    Returns:
        ``32 // num_bits``.

    Raises:
        AssertionError: If ``num_bits`` is not listed in
            ``SUPPORTED_NUM_BITS``.
    """
    is_supported = num_bits in SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"

    word_bits = 32
    return word_bits // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Randomly permute the rows of a quantized weight to simulate act_order.

    Args:
        q_w: 2-D tensor of quantized values; rows run along the K dimension.
        w_ref: Reference (dequantized) tensor with the same shape as ``q_w``.
        group_size: Number of consecutive rows that share one quantization
            group before the permutation is applied.

    Returns:
        A tuple ``(w_ref, q_w, g_idx, rand_perm)``. Note that ``w_ref`` comes
        first although ``q_w`` is the first parameter.

        * ``w_ref`` / ``q_w``: the inputs with their rows permuted.
        * ``g_idx``: int32 tensor of length K holding, for every permuted
          row, the group index of the row it came from.
        * ``rand_perm``: the permutation that was applied (as produced by
          ``torch.randperm``).

        All four tensors are moved to the device of the incoming ``q_w``.

    Raises:
        AssertionError: If ``q_w`` and ``w_ref`` differ in shape.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Row i belongs to group i // group_size. Built with one vectorized
    # floor division instead of a Python loop writing one element at a time.
    g_idx = torch.div(torch.arange(k_size, dtype=torch.int32),
                      group_size,
                      rounding_mode="floor")

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    """Quantize a 2-D float weight symmetrically, one scale per K-group.

    Args:
        w: Floating-point tensor of shape ``(size_k, size_n)``.
        num_bits: Bit-width of the quantized values; must be listed in
            ``SUPPORTED_NUM_BITS``.
        group_size: Rows per quantization group. Must be listed in
            ``SUPPORTED_GROUP_SIZES`` or equal ``size_k``; ``-1`` is treated
            as ``size_k``.
        act_order: When true, the rows of the results are additionally
            shuffled through ``permute_rows``.

    Returns:
        A tuple ``(w_ref, q_w, s, g_idx, rand_perm)``, all moved to the
        device of the incoming ``w``:

        * ``w_ref``: dequantized reference, shape ``(size_k, size_n)``.
        * ``q_w``: int32 codes clamped to ``[0, 2**num_bits - 1]``, shape
          ``(size_k, size_n)``.
        * ``s``: per-group scales reshaped to ``(-1, size_n)``.
        * ``g_idx`` / ``rand_perm``: empty int tensors unless ``act_order``
          is set, in which case they are the values from ``permute_rows``.

    Raises:
        AssertionError: On a non-float ``w``, an unsupported ``num_bits`` or
            ``group_size``, or ``act_order`` without real grouping.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    # -1 means "one group covering all of K".
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2
    is_grouped = group_size < size_k

    # Lay the data out as [groupsize, -1] so every column is one group.
    if is_grouped:
        w = (w.reshape((-1, group_size, size_n)).permute(1, 0, 2).reshape(
            (group_size, -1)))

    # One scale per column (= per group): max |w| scaled by 2 / max_q_val.
    s = torch.abs(w).max(0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric

    # Round to integer codes, shift into the unsigned range and clamp.
    q_w = (torch.round(w / s).int() + half_q_val).clamp(0, max_q_val)

    # Dequantized reference; the integer codes go through half precision
    # before being multiplied by the scales.
    w_ref = (q_w - half_q_val).half() * s

    # Undo the grouping layout to get back to (size_k, size_n).
    if is_grouped:

        def _ungroup(t):
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w, w_ref = _ungroup(q_w), _ungroup(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Placeholders returned when no act_order permutation is applied.
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            is_grouped
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    results = (w_ref, q_w, s, g_idx, rand_perm)
    return tuple(t.to(device=orig_device) for t in results)
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that their group indices are ascending.

    Args:
        q_w: 2-D tensor whose rows are to be reordered.
        g_idx: 1-D tensor with one group index per row of ``q_w``.

    Returns:
        A tuple ``(q_w, g_idx, sort_indices)``: the reordered rows, the
        sorted group indices, and the int32 permutation that was applied.
        All three are moved to the device of the incoming ``q_w``.
    """
    target_device = q_w.device

    # Permutation that puts g_idx in ascending order, stored as int32.
    order = torch.argsort(g_idx).int()

    sorted_g_idx = g_idx[order].contiguous()
    sorted_q_w = q_w[order, :].contiguous()

    return (
        sorted_q_w.to(device=target_device),
        sorted_g_idx.to(device=target_device),
        order.to(device=target_device),
    )
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack quantized values along K into 32-bit words.

    Every ``pack_factor = get_pack_factor(num_bits)`` consecutive rows are
    folded into one output row: input row ``k * pack_factor + i`` is shifted
    left by ``num_bits * i`` bits and OR-ed into output row ``k``.

    Args:
        q_w: Tensor of shape ``(size_k, size_n)`` holding the quantized
            values (presumably each fits in ``num_bits`` bits; this is not
            checked here).
        num_bits: Bit-width of one quantized value.
        size_k: Expected number of rows; must be a multiple of the pack
            factor.
        size_n: Expected number of columns.

    Returns:
        An int32 tensor of shape ``(size_k // pack_factor, size_n)`` on the
        device of the incoming ``q_w``.

    Raises:
        AssertionError: If the shape does not match or ``size_k`` is not
            divisible by the pack factor.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    # The packing is done with NumPy on the host using unsigned 32-bit words.
    values = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for slot in range(pack_factor):
        shift = num_bits * slot
        packed |= values[slot::pack_factor, :] << shift

    # Reinterpret as int32 for torch and move back to the original device.
    return torch.from_numpy(packed.astype(numpy.int32)).to(orig_device)
...@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module): ...@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf. https://arxiv.org/pdf/2302.01318.pdf.
""" """
def __init__(self, strict_mode: bool = False): def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
"""Create a rejection sampler. """Create a rejection sampler.
Args: Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens would corrupt the KV cache for
proposal methods that require a KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
""" """
super().__init__() super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are # NOTE: A "bonus token" is accepted iff all proposal tokens are
...@@ -116,6 +122,7 @@ class RejectionSampler(nn.Module): ...@@ -116,6 +122,7 @@ class RejectionSampler(nn.Module):
draft_token_ids, draft_token_ids,
bonus_token_ids, bonus_token_ids,
) )
return output_token_ids return output_token_ids
def _batch_modified_rejection_sampling( def _batch_modified_rejection_sampling(
...@@ -312,7 +319,8 @@ class RejectionSampler(nn.Module): ...@@ -312,7 +319,8 @@ class RejectionSampler(nn.Module):
# proposal methods that require KV cache. We can fix it by "prefilling" # proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix. # the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212 # https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1 if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids. # Fill the recovered token ids.
output.mul_(~after_false_mask).add_( output.mul_(~after_false_mask).add_(
......
...@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module): ...@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype,
) -> None: ) -> None:
super().__init__() super().__init__()
self.head_size = head_size self.head_size = head_size
...@@ -60,9 +61,10 @@ class RotaryEmbedding(nn.Module): ...@@ -60,9 +61,10 @@ class RotaryEmbedding(nn.Module):
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
self.is_neox_style = is_neox_style self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
cache = cache.to(torch.get_default_dtype()) cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
...@@ -109,7 +111,7 @@ class RotaryEmbedding(nn.Module): ...@@ -109,7 +111,7 @@ class RotaryEmbedding(nn.Module):
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device) positions.device, dtype=query.dtype)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets) cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions] if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
...@@ -143,7 +145,8 @@ class RotaryEmbedding(nn.Module): ...@@ -143,7 +145,8 @@ class RotaryEmbedding(nn.Module):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.device) self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding() # ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors. # are in-place operations that update the query and key tensors.
if offsets is not None: if offsets is not None:
...@@ -166,6 +169,29 @@ class RotaryEmbedding(nn.Module): ...@@ -166,6 +169,29 @@ class RotaryEmbedding(nn.Module):
class LinearScalingRotaryEmbedding(RotaryEmbedding): class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling. """RotaryEmbedding extended with linear scaling.
It supports multiple scaling factors. Since multiple LoRA adapters may have
different scaling factors, we need multiple cos/sin caches. In this way,
instead of running rotary embedding kernel per lora, we can run multiple
lora in a batched way.
In addition to that, we also keep the cos/sin cache for the scaling factor
of 1 (default) at all times.
For example, for scaling factors x=1 (the default), y and z with embeddings
[[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
[[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
[[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
we construct the cos/sin cache as follows:
[[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
...
[xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
We then use offsets to index into the cos/sin cache for
the respective scaling factors.
The offset to cache can be accessed via `scaling_factor_to_offset` API.
Credits to the Reddit user /u/kaiokendev Credits to the Reddit user /u/kaiokendev
""" """
...@@ -177,16 +203,22 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -177,16 +203,22 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
scaling_factors: Union[List[float], float], scaling_factors: Union[List[float], float],
dtype: torch.dtype,
) -> None: ) -> None:
if isinstance(scaling_factors, float): if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors] scaling_factors = [scaling_factors]
self.scaling_factors = scaling_factors self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style) is_neox_style, dtype)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base) inv_freq = self._compute_inv_freq(self.base)
cache_list = [] cache_list: List[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = []
for scaling_factor in self.scaling_factors: for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original # NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling. # maximum length before applying the rope scaling.
...@@ -200,9 +232,25 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -200,9 +232,25 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache) cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0) return torch.cat(cache_list, dim=0)
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
return self._scaling_factor_to_offset
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling. """RotaryEmbedding extended with Dynamic NTK scaling.
...@@ -218,10 +266,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): ...@@ -218,10 +266,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style) is_neox_style, dtype)
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original # NOTE(woosuk): self.max_position_embeddings is the original
...@@ -298,6 +347,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -298,6 +347,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype,
*, *,
extrapolation_factor: float = 1, extrapolation_factor: float = 1,
attn_factor: float = 1, attn_factor: float = 1,
...@@ -313,7 +363,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -313,7 +363,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.mscale = float( self.mscale = float(
_yarn_get_mscale(self.scaling_factor) * attn_factor) _yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style) is_neox_style, dtype)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**( pos_freqs = self.base**(
...@@ -358,6 +408,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module): ...@@ -358,6 +408,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
original_max_position_embeddings: int, original_max_position_embeddings: int,
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype,
short_factor: List[float], short_factor: List[float],
long_factor: List[float], long_factor: List[float],
short_mscale: float = 1.1, short_mscale: float = 1.1,
...@@ -384,14 +435,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module): ...@@ -384,14 +435,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
short_cache = self._compute_cos_sin_cache( short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale) original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype()) short_cache = short_cache.to(dtype)
self.register_buffer("short_cos_sin_cache", self.register_buffer("short_cos_sin_cache",
short_cache, short_cache,
persistent=False) persistent=False)
long_cache = self._compute_cos_sin_cache(max_position_embeddings, long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale) long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype()) long_cache = long_cache.to(dtype)
self.register_buffer("long_cos_sin_cache", self.register_buffer("long_cos_sin_cache",
long_cache, long_cache,
persistent=False) persistent=False)
...@@ -462,7 +513,10 @@ def get_rope( ...@@ -462,7 +513,10 @@ def get_rope(
base: int, base: int,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None: if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls # Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = { rope_scaling_tuple = {
...@@ -473,12 +527,12 @@ def get_rope( ...@@ -473,12 +527,12 @@ def get_rope(
else: else:
rope_scaling_args = None rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style, key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args) rope_scaling_args, dtype)
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if rope_scaling is None: if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style) is_neox_style, dtype)
else: else:
scaling_type = rope_scaling["type"] scaling_type = rope_scaling["type"]
if scaling_type != "su": if scaling_type != "su":
...@@ -487,11 +541,11 @@ def get_rope( ...@@ -487,11 +541,11 @@ def get_rope(
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base, max_position, base,
is_neox_style, is_neox_style,
scaling_factor) scaling_factor, dtype)
elif scaling_type == "dynamic": elif scaling_type == "dynamic":
rotary_emb = DynamicNTKScalingRotaryEmbedding( rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor) scaling_factor, dtype)
elif scaling_type == "yarn": elif scaling_type == "yarn":
original_max_position = rope_scaling[ original_max_position = rope_scaling[
"original_max_position_embeddings"] "original_max_position_embeddings"]
...@@ -504,7 +558,7 @@ def get_rope( ...@@ -504,7 +558,7 @@ def get_rope(
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position, original_max_position,
base, is_neox_style, base, is_neox_style,
scaling_factor, scaling_factor, dtype,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "su": elif scaling_type == "su":
short_factor = rope_scaling["short_factor"] short_factor = rope_scaling["short_factor"]
...@@ -518,7 +572,8 @@ def get_rope( ...@@ -518,7 +572,8 @@ def get_rope(
} }
rotary_emb = Phi3SuScaledRotaryEmbedding( rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position, head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs) base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
......
...@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata, ...@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors, SamplingTensors,
SequenceGroupToSample) SequenceGroupToSample)
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupOutput, SequenceOutput) PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput)
# (num_token_ids, num_parent_ids) per sequence group. # (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]] SampleResultType = List[Tuple[List[int], List[int]]]
...@@ -680,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ...@@ -680,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
""" """
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
indices] indices]
return (x > vals[:, None]).long().sum(1).add_(1) result = (x > vals[:, None])
del vals
return result.sum(1).add_(1)
def _get_logprobs( def _get_logprobs(
...@@ -782,13 +785,14 @@ def _get_logprobs( ...@@ -782,13 +785,14 @@ def _get_logprobs(
top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs, largest_num_logprobs,
dim=-1) dim=-1)
top_logprobs = top_logprobs.cpu()
top_token_ids = top_token_ids.cpu()
else: else:
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
selected_logprobs = selected_logprobs.cpu() selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.cpu() ranks = ranks.to('cpu')
if top_logprobs is not None and top_token_ids is not None:
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')
# Find prompt/sample logprobs. # Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
...@@ -828,37 +832,48 @@ def _get_prompt_logprob_if_needed( ...@@ -828,37 +832,48 @@ def _get_prompt_logprob_if_needed(
# Find prompt logprobs # Find prompt logprobs
prompt_logprobs: Optional[PromptLogprobs] = None prompt_logprobs: Optional[PromptLogprobs] = None
if (is_prompt and sampling_params.prompt_logprobs is not None): if is_prompt and sampling_params.prompt_logprobs is not None:
prompt_logprobs = [] prompt_logprobs = []
num_logprobs = sampling_params.prompt_logprobs num_logprobs = sampling_params.prompt_logprobs
next_prompt_tokens = _get_next_prompt_tokens(seq_group) next_prompt_tokens = _get_next_prompt_tokens(seq_group)
for token_id in next_prompt_tokens: # Pre-select indexes and create a list. It is faster than calling .item
# repetitively.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_prompt_tokens)].tolist()
for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens. # Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)} # {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
token_id: (selected_logprobs[selected_logprobs_idx].item(), token_id: (selected_logprob_items[idx], rank_items[idx])
ranks[selected_logprobs_idx].item())
} }
# Add top K prompt logprobs along with its rank. # Add top K prompt logprobs along with its rank.
if num_logprobs > 0: if num_logprobs > 0:
prompt_logprobs_dict.update( top_ids = top_token_ids[
zip( top_logprob_idx, :num_logprobs].tolist()
top_token_ids[top_logprob_idx, :num_logprobs].tolist(), top_probs = top_logprobs[
zip( top_logprob_idx, :num_logprobs].tolist()
top_logprobs[ # Top K is already sorted by rank, so we can use 1 ~
top_logprob_idx, :num_logprobs].tolist(), # num_logprobs + 1 for rank.
# This is ranks. Since top_logprob is sorted, top_ranks = range(1, num_logprobs + 1)
# we can just use a range here. prompt_logprobs_dict.update({
range(1, num_logprobs + 1)))) top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
})
prompt_logprobs.append({ prompt_logprobs.append({
token_id: Logprob(*logprob_and_rank) token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in prompt_logprobs_dict.items() for token_id, logprob_and_rank in prompt_logprobs_dict.items()
}) })
# + 1 to go to the next prompt token. # + 1 to go to the next prompt token.
top_logprob_idx += 1 top_logprob_idx += 1
selected_logprobs_idx += 1
# + len(next_prompt_tokens) to go to the next prompt.
selected_logprobs_idx += len(next_prompt_tokens)
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
...@@ -874,47 +889,54 @@ def _get_sampled_logprob_if_needed( ...@@ -874,47 +889,54 @@ def _get_sampled_logprob_if_needed(
): ):
"""Compute the sample logprob if needed.""" """Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs num_logprobs = seq_group.sampling_params.logprobs or 0
if num_logprobs is None:
num_logprobs = 0
sampled_logprobs: SampleLogprobs = [] sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample: if seq_group.do_sample:
assert len(next_token_ids) > 0 assert len(next_token_ids) > 0
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids): # Pre-select items from tensor. tolist() is faster than repetitive
# Calculate the sample logprob of the real sampled tokens. # `.item()` calls.
# Use tuple here for performance (to use to_list()). selected_logprob_items = selected_logprobs[
# token_id: (logprob, rank_from_vocab) selected_logprobs_idx:selected_logprobs_idx +
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = { len(next_token_ids)].tolist()
next_token_id: rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
(selected_logprobs[selected_logprobs_idx].item(), len(next_token_ids)].tolist()
ranks[selected_logprobs_idx].item()) for idx, (next_token_id,
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id: (selected_logprob_items[idx], rank_items[idx])
} }
# +1 to go to the next sampled token. Note that # Get top K logprobs.
# selected_logprobs can contain duplicates unlike top_logprobs if num_logprobs > 0:
# when beam search is enabled. top_ids = top_token_ids[top_logprob_idx +
selected_logprobs_idx += 1 parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[top_logprob_idx +
# Second, add top K logprobs along with its rank. parent_id, :num_logprobs].tolist()
if num_logprobs >= 0: # Top K is already sorted by rank, so we can use 1 ~
sampled_logprobs_dict.update( # num_logprobs + 1 for rank.
zip( top_ranks = range(1, num_logprobs + 1)
top_token_ids[top_logprob_idx + sampled_logprobs_dict.update({
parent_id, :num_logprobs].tolist(), top_id: (top_prob, rank)
zip( for top_id, top_prob, rank in zip(top_ids, top_probs,
top_logprobs[top_logprob_idx + top_ranks)
parent_id, :num_logprobs].tolist(), })
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range(1, num_logprobs + 1))))
sampled_logprobs.append({ sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank) token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in for token_id, logprob_and_rank in
sampled_logprobs_dict.items() sampled_logprobs_dict.items()
}) })
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group. # NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# steps, which has len(seq_ids) tokens per sequence group.
# Iterate to the next sequence group in a batch.
selected_logprobs_idx += len(next_token_ids)
# Iterate to the next sequence group in a batch.
top_logprob_idx += len(seq_ids) top_logprob_idx += len(seq_ids)
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
...@@ -1000,7 +1022,7 @@ def _build_sampler_output( ...@@ -1000,7 +1022,7 @@ def _build_sampler_output(
seq_outputs.append( seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append( sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
# If not specified, store None values in SamplerOutput. # If not specified, store None values in SamplerOutput.
if on_device_tensors is not None: if on_device_tensors is not None:
......
...@@ -2,26 +2,29 @@ from typing import Optional ...@@ -2,26 +2,29 @@ from typing import Optional
from torch import nn from torch import nn
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader, from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader) get_model_loader)
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture) get_architecture_class_name, get_model_architecture)
def get_model( def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
*, model_config: ModelConfig, load_config: LoadConfig, device_config: DeviceConfig, parallel_config: ParallelConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig,
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: vision_language_config: Optional[VisionLanguageConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config) loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config, return loader.load_model(model_config=model_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
parallel_config=parallel_config, parallel_config=parallel_config,
scheduler_config=scheduler_config) scheduler_config=scheduler_config,
cache_config=cache_config)
__all__ = [ __all__ = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment