Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
......@@ -28,11 +28,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
......@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module):
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
......@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
......@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module):
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
linear_method=linear_method)
quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
......@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method)
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
......@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method)
StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
......@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.quant_config = quant_config
self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
......@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -121,21 +122,20 @@ class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
......@@ -150,12 +150,11 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
linear_method=linear_method)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -192,7 +191,7 @@ class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
......@@ -202,7 +201,7 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method)
Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
......@@ -227,10 +226,10 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method)
self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
......
......@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -52,17 +53,17 @@ class XverseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -85,7 +86,7 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
......@@ -112,13 +113,13 @@ class XverseAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
)
......@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
......@@ -236,7 +237,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
XverseDecoderLayer(config, linear_method)
XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = XverseModel(config, linear_method)
self.quant_config = quant_config
self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -6,57 +6,284 @@ import torch
from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import is_pin_memory_available
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
@dataclass
class SequenceGroupToSample:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
seq_ids: List[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
seq_data: Dict[int, SequenceData]
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
seq_len: Optional[int]
# The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of query_len <= seq_len if chunked
# prefill is enabled.
query_len: Optional[int]
# A random number generator for sampling.
generator: Optional[torch.Generator]
# True if the sequence group is in prefill stage. False if it is in a
# decode stage.
is_prompt: bool
# Query token indices from logits. to compute prompt logprob. Empty if
# prompt logprob is not required.
prompt_logprob_indices: List[int]
# Sample token indices from logits. Empty if sampling is not required.
sample_indices: List[int]
@property
def do_sample(self):
return len(self.sample_indices) > 0
def __post_init__(self):
if len(self.prompt_logprob_indices) > 0:
assert self.sampling_params.prompt_logprobs is not None
if self.is_prompt:
assert self.seq_len is not None
assert self.query_len is not None
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.
The usage is as follow;
```
hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)
def sample(logits):
# Use categorized_sample_indices for sampling....
```
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
seq_groups: List of batched sequence groups.
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
logits from the initial model output hidden states.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
Each token indices is 2D tensor of (num_indices, num_indices) where
the first item means the sample index within the returned logit
(before pruning padding), and the second item means the sample
index after pruning using selected_token_indices.
For example, if the returned logit is [1, 2, 3], and we select
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
"""
def __init__(
self,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
seq_data: Optional[Dict[int, SequenceData]],
prompt_lens: Optional[List[int]],
seq_groups: List[SequenceGroupToSample],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
generators: Optional[List[torch.Generator]] = None,
perform_sampling: bool = True,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.generators = generators
self.perform_sampling = perform_sampling
self.num_prompts = num_prompts
self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
@staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
pin_memory: bool,
) -> "SamplingMetadata":
(
seq_groups,
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
num_prompts=num_prompts,
)
return sampling_metadata
def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices}), "
f"perform_sampling={self.perform_sampling})")
f"categorized_sample_indices={self.categorized_sample_indices}), ")
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.
Args:
seq_group_metadata_list: A list of sequence group to batch.
seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
Returns:
seq_groups: A list of sequence group to sample.
selected_token_indices: See the definition from `SamplingMetadata`.
categorized_sample_indices: See the definition from `SamplingMetadata`.
num_prompts: Total number of prompts from `seq_group_metadata_list`.
"""
# Batched sequence groups for the current model forward stsep.
seq_groups: List[SequenceGroupToSample] = []
# A list of token indices to sample/compute logprob. It is used to
# prune the outcome logits from the model for the performance.
selected_token_indices: List[int] = []
# Used for selected_token_indices.
model_output_idx = 0
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups.
num_prompts = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=device).manual_seed(sampling_params.seed)
num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
assert query_lens is not None and seq_lens is not None
query_len, seq_len = query_lens[i], seq_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
"""
if sampling_params.prompt_logprobs:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + prompt_logprob_len))
model_output_idx += prompt_logprob_len
if do_sample:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + sample_len))
model_output_idx += sample_len
# We now find indices for logprob computation and sampling.
"""
This block computes categorized_sample_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
def sample(logits):
# Use categorized_sample_indices for sampling.
# prompt_logprob_indices to find prompt logprob indices.
# sample_indices to find sample indices.
"""
if sampling_params.prompt_logprobs is not None:
prompt_logprob_indices.extend(
range(logit_idx, logit_idx + prompt_logprob_len))
logit_idx += prompt_logprob_len
if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend(
list(
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
logit_idx += sample_len
sample_idx += sample_len
if sampling_params.seed is not None:
generator = seq_group_metadata.state.generator
seq_groups.append(
SequenceGroupToSample(
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
seq_len=seq_len,
query_len=query_len,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)))
return (seq_groups, selected_token_indices, categorized_sample_indices,
num_prompts)
@dataclass
......@@ -112,11 +339,10 @@ class SamplingTensors:
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
temperature = sampling_params.temperature
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
......@@ -145,45 +371,46 @@ class SamplingTensors:
or abs(r - 1.0) >= _SAMPLING_EPS):
do_penalties = True
if (i < sampling_metadata.num_prompts
is_prompt = seq_group.is_prompt
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
is_prompt = i < sampling_metadata.num_prompts
query_len = seq_group.query_len
assert query_len is not None
prefill_len = len(seq_group.prompt_logprob_indices)
temperatures += [temperature] * prefill_len
top_ps += [top_p] * prefill_len
top_ks += [top_k] * prefill_len
min_ps += [min_p] * prefill_len
presence_penalties += [0] * prefill_len
frequency_penalties += [0] * prefill_len
repetition_penalties += [1] * prefill_len
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))
if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
query_len = seq_group.query_len
assert query_len is not None
if sampling_params.prompt_logprobs is not None:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx += prompt_len - 1
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
......@@ -193,8 +420,7 @@ class SamplingTensors:
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.append(sample_indices_start_idx)
sample_indices_start_idx += 1
sample_indices.extend(seq_group.sample_indices)
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
......@@ -217,12 +443,14 @@ class SamplingTensors:
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
output_max_len = max(len(tokens) for tokens in output_tokens)
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens
......
......@@ -139,7 +139,10 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.seed = seed
if seed == -1:
self.seed = None
else:
self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
......@@ -185,8 +188,8 @@ class SamplingParams:
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# injected by the engine
self.eos_token_id = None
# eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None:
if self.n < 1:
......@@ -275,7 +278,8 @@ class SamplingParams:
self, generation_config: Dict[str, Any]) -> None:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if eos_ids := generation_config.get("eos_token_id"):
if (not self.ignore_eos) and (eos_ids :=
generation_config.get("eos_token_id")):
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
......
"""Sequence and its related classes."""
import copy
import enum
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock
......@@ -28,7 +28,10 @@ class Logprob:
decoded_token: Optional[str] = None
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
......@@ -215,7 +218,7 @@ class Sequence:
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids)
self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
......@@ -439,15 +442,27 @@ class SequenceGroup:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def get_last_latency(self, now: float) -> float:
"""Gets last token latency for Request level timings."""
def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
raise ValueError(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase.")
# Otherwise return token latency.
latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now
return latency
def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings."""
if self.metrics.first_token_time is None:
# Note: in a case where a sequence_group is swapped and
# recomputed, the time between iterations is counted
# in TPOT, rather than recalculating TTFT (since from the )
# POV of the user, there is simply a long generation delay.
if (self.metrics.first_token_time is None
and self.get_seqs()[0].get_output_len() == 1):
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
......@@ -559,10 +574,15 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
do_sample: True if sampling is required. Sampling is not required when
e.g., prefill is chunked, and the current iteration only computes
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
"""
......@@ -573,6 +593,7 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
......@@ -589,6 +610,7 @@ class SequenceGroupMetadata:
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
if self._token_chunk_size is None:
if is_prompt:
......@@ -650,6 +672,7 @@ class SequenceGroupOutput:
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str:
......@@ -677,6 +700,9 @@ class SamplerOutput:
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None
......@@ -708,3 +734,33 @@ class SamplerOutput:
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
@dataclass
class ExecuteModelRequest:
"""The model execution request."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
running_queue_size: int = 0
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(),
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
)
from itertools import chain, count
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Iterator, List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
......@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
......@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence.
Args:
seq_group_metadata_list: The input sequence group metadata.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
......@@ -80,33 +73,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
target_sampler_output = self._scorer_worker.execute_model(
seq_group_metadata_list=target_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list, ))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
contracted_bs=len(seq_group_metadata_list),
all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=k,
k=execute_model_req.num_lookahead_slots,
)
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
)
def _expand_batch(
......@@ -148,12 +139,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
def _contract_batch(
self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
......@@ -161,8 +152,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
(target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
......@@ -179,6 +171,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size)
all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
......@@ -189,16 +183,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self._vocab_size,
device=self._device,
dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
return all_tokens, all_probs
return all_tokens, all_probs, all_logprobs
def _create_scoring_model_input(
self,
......@@ -308,7 +312,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""Split the target model output into speculative and non-speculative
output.
"""
......@@ -328,21 +333,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
target_token_ids, target_probs = sampler_output_to_torch(
[sampler_output])
sampler_output.logprobs = spec_logprobs
(target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = (
sampler_output_to_torch([sampler_output]))
return (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs)
sampler_output.logprobs = non_spec_logprobs
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
True)
return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
from vllm.sequence import SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest
@dataclass
......@@ -38,6 +37,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model.
probs: torch.Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs: torch.Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids: torch.Tensor
......@@ -53,11 +57,7 @@ class SpeculativeProposer(ABC):
@abstractmethod
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
raise NotImplementedError
......@@ -67,11 +67,7 @@ class SpeculativeScorer(ABC):
@abstractmethod
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
raise NotImplementedError
import copy
from typing import Dict, List, Optional, Tuple
from typing import List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
......@@ -26,50 +26,53 @@ class MultiStepWorker(Worker):
super().__init__(*args, **kwargs)
# Lazy initialization list.
self._proposer: DraftModelTop1Proposer
self._proposer: Top1Proposer
def init_device(self):
super().init_device()
self._proposer = DraftModelTop1Proposer(
self._proposer = Top1Proposer(
self,
self.device,
self.max_model_len,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
# Need include_gpu_probs_tensor for multi_step_worker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
@torch.inference_mode()
def execute_model_multi_step(
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_steps: int,
) -> List[SamplerOutput]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
self._raise_if_unsupported(execute_model_req)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
# Run model num_steps times.
# Run model sample_len times.
model_outputs = []
for _ in range(num_steps):
for _ in range(sample_len):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
......@@ -78,27 +81,17 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs
return model_outputs, True
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
return self._proposer.get_proposals(execute_model_req)
def _append_new_tokens(
self, model_output: SamplerOutput,
......@@ -189,188 +182,22 @@ class MultiStepWorker(Worker):
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
class DraftModelTop1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
draft_worker: MultiStepWorker,
device: str,
max_model_len: int,
vocab_size: int,
):
self._draft_worker = draft_worker
self._device = device
self._max_model_len = max_model_len
self._vocab_size = vocab_size
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
(proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices) = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
maybe_sampler_output = self._draft_worker.execute_model_multi_step(
seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_steps=max_proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
max_proposal_len=max_proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
max_proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length.
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
if seq_len + max_proposal_len < self._max_model_len:
proposal_lens.append(max_proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return (proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices)
def _merge_outputs(
self,
batch_size: int,
max_proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(size=(
batch_size,
max_proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full(size=(batch_size,
*proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor
from typing import List, Optional, Tuple
import torch
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(LoraNotSupportedWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scenerios
which don't rely on LLM model to give proposals.
"""
def __init__(self, *args, **kwargs):
# Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["model_config"].get_vocab_size()
# Lazy initialization list.
self._proposer: Top1Proposer
def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
ngram_prompt_lookup_max: int):
# Search valid candidate window between
# ngram_prompt_lookup_min/ngram_prompt_lookup_max
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Current only support Top1Proposer
self._proposer = Top1Proposer(
self,
device=self.device,
vocab_size=self.vocab_size,
)
def set_include_gpu_probs_tensor(self):
# NGram don't need gpu sampler
pass
def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def determine_num_available_blocks(self) -> None:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""As there is no cache need to handle, just pass this function"""
pass
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes."""
return 0
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self._raise_if_unsupported(execute_model_req)
arr = []
has_spec_out = False
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))
input_ids = torch.as_tensor(seq_data.get_token_ids(),
dtype=torch.long,
device=self.device)
input_length = seq_data.get_len()
for ngram_size in range(
min(self.ngram_prompt_lookup_max, input_length - 1),
self.ngram_prompt_lookup_min,
-1,
):
ngram_tensor = input_ids[-1 * ngram_size:]
windows = input_ids.unfold(dimension=0,
size=ngram_size,
step=1)
matches = (windows == ngram_tensor).all(dim=1)
match_indices = matches.nonzero(as_tuple=True)[0]
if match_indices.size()[0] > 1:
has_spec_out = True
res = seq_data.get_token_ids()
res = res[match_indices[0] + ngram_size:match_indices[0] +
ngram_size + sample_len]
res_len = len(res)
# pad 0 towards output as sample_len tokens required
res += [0] * (sample_len - res_len)
break
else:
# if no candidate found, fill with 0
res = [0] * sample_len
arr.append(res)
if not has_spec_out:
return None, False
outputs = []
token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
indices = token_ids.unsqueeze(2)
token_probs = torch.zeros(
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros(
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
for i in range(len(execute_model_req.seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_probs[i],
logprobs=token_logprobs,
sampled_token_ids=token_ids[i],
))
return outputs, False
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(execute_model_req)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"NGramWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"NGramWorker does not support beam search.")
from functools import cached_property
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
......@@ -48,8 +51,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
@classmethod
def from_workers(cls, proposer_worker: MultiStepWorker,
scorer_worker: WorkerBase) -> "SpecDecodeWorker":
def create_worker(
cls,
scorer_worker: WorkerBase,
draft_worker_kwargs,
) -> "SpecDecodeWorker":
if "ngram_prompt_lookup_max" in draft_worker_kwargs:
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
else:
ngram_prompt_lookup_max = 0
if ngram_prompt_lookup_max > 0:
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
......@@ -59,7 +81,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
proposer_worker: MultiStepWorker,
proposer_worker: WorkerBase,
scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
......@@ -134,8 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.proposer_worker.model_runner.model.sampler.
include_gpu_probs_tensor) = True
self.proposer_worker.set_include_gpu_probs_tensor()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
......@@ -169,68 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
assert seq_group_metadata_list is not None, (
assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding "
"requires non-None seq_group_metadata_list")
logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return self._run_speculative_decoding_step(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
k=num_lookahead_slots,
)
if execute_model_req.num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req)
return self._run_speculative_decoding_step(execute_model_req)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
"""
logger.info("run proposer worker no spec")
#logger.info("run proposer worker no spec")
self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
self.proposer_worker.execute_model(execute_model_req)
logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
#logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
......@@ -238,17 +228,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
sampler_output.logprobs = None
return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
......@@ -258,32 +244,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
logger.info("get spec proposals")
#logger.info("get spec proposals")
# Generate proposals using draft worker.
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
logger.info("score proposals")
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
#logger.info("score proposals")
proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
k,
execute_model_req,
proposals,
)
logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
#logger.info("verify proposals")
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)
#logger.info("create output list")
return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=execute_model_req.num_lookahead_slots)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
......@@ -292,9 +273,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list = proposals.proposal_lens.tolist()
......@@ -341,17 +325,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()
return accepted_token_ids
return accepted_token_ids, logprobs
def _create_output_sampler_list(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int,
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
......@@ -359,30 +345,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids = get_all_seq_ids(seq_group_metadata_list)
# shape: [k+1, batch_size]
accepted_token_ids_by_step = accepted_token_ids.transpose(0,
1).tolist()
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list = []
for token_ids_by_step in accepted_token_ids_by_step:
if all(token_id == -1 for token_id in token_ids_by_step):
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for token_id, seq_id in zip(token_ids_by_step, seq_ids):
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
# TODO Add verifier logprobs.
logprobs={token_id: Logprob(0.0)},
)
],
prompt_logprobs=None,
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
......
from typing import List, Optional, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker_base import WorkerBase
class Top1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
worker: WorkerBase,
device: str,
vocab_size: int,
max_proposal_len: Optional[int] = None,
):
self._worker = worker
self._device = device
self.max_proposal_len = max_proposal_len
self._vocab_size = vocab_size
def get_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
proposal_len = execute_model_req.num_lookahead_slots
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is true, then maybe_sampler_output's
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
nonzero_execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=nonzero_proposal_len_seqs,
num_lookahead_slots=proposal_len,
)
maybe_sampler_output, transposed = self._worker.sampler_output(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
transposed = False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
proposal_len=proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
sampler_transposed=transposed,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length."""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exccess this
# quota for nonzero_proposal
if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len):
proposal_lens.append(proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return (
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
)
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(
size=(
batch_size,
proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
proposal_probs = torch.zeros(
batch_size,
proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device,
)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(
batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device,
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor
from contextlib import contextmanager
from itertools import chain
from typing import List, Tuple
from typing import Dict, List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
SeqId = int
......@@ -21,6 +22,89 @@ def get_all_seq_ids(
]))
def get_all_num_logprobs(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)
return all_num_logprobs
def get_sampled_token_logprobs(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps, batch_size, vocab_size = logprob_tensor.shape
selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids, ]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >=
expanded_selected_logprobs).sum(-1)
return sampled_token_ids_ranks, selected_logprobs
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
) -> SequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[int]): The list of top-k token ids.
topk_logprobs (List[float]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
})
return SequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
# TODO add prompt logprobs support.
prompt_logprobs=None,
)
def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool
......@@ -49,10 +133,13 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
......@@ -68,7 +155,19 @@ def sampler_output_to_torch(
for sampler_output in sampler_output_list
],
dim=0,
).transpose(0, 1)
)
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
if sampler_transposed:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
......@@ -77,9 +176,11 @@ def sampler_output_to_torch(
for sampler_output in sampler_output_list
],
dim=0,
).transpose(0, 1)
)
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)
return sampled_token_ids, sampled_token_probs
return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
......
......@@ -72,9 +72,10 @@ class DbrxAttentionConfig(PretrainedConfig):
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
"You are using a model of type %s to instantiate a model of "
"type %s. This is not supported for all configurations of "
"models and can yield errors.",
config_dict["model_type"], cls.model_type)
return cls.from_dict(config_dict, **kwargs)
......@@ -151,9 +152,9 @@ class DbrxFFNConfig(PretrainedConfig):
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
"You are using a model of type %s to instantiate a model of "
"type %s. This is not supported for all "
"configurations of models and can yield errors.", config_dict["model_type"], cls.model_type)
return cls.from_dict(config_dict, **kwargs)
......
import os
from typing import Optional, Union
import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.config import VLLM_USE_MODELSCOPE
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
......@@ -58,11 +59,12 @@ def get_tokenizer(
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface/modelscope."""
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
......@@ -74,9 +76,10 @@ def get_tokenizer(
tokenizer_path = snapshot_download(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=tokenizer_revision,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
tokenizer_name = tokenizer_path
if tokenizer_mode == "slow":
......@@ -90,7 +93,7 @@ def get_tokenizer(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
......@@ -114,7 +117,7 @@ def get_tokenizer(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
revision=revision,
**kwargs)
else:
raise e
......@@ -137,9 +140,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
f"No tokenizer found in {lora_request.lora_local_path}, "
"using base model tokenizer instead. "
f"(Exception: {str(e)})")
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_local_path, e)
tokenizer = None
return tokenizer
......
from typing import Optional
from vllm.config import TokenizerPoolConfig
from vllm.engine.ray_utils import ray
from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
......
......@@ -6,7 +6,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.engine.ray_utils import ray
from vllm.executor.ray_utils import ray
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
......
......@@ -15,20 +15,22 @@ import psutil
import requests
import torch
_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
import vllm.envs as envs
_config_home = envs.VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json")
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home,
"vllm/do_not_track")
_USAGE_STATS_ENABLED = None
_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER",
"https://stats.vllm.ai")
_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER
def is_usage_stats_enabled():
"""Determine whether or not we can send usage stats to the server.
The logic is as follows:
- By default, it should be enabled.
- Two environment variables can disable it:
- Three environment variables can disable it:
- VLLM_DO_NOT_TRACK=1
- DO_NOT_TRACK=1
- VLLM_NO_USAGE_STATS=1
- A file in the home directory can disable it if it exists:
......@@ -36,8 +38,8 @@ def is_usage_stats_enabled():
"""
global _USAGE_STATS_ENABLED
if _USAGE_STATS_ENABLED is None:
do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1"
no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1"
do_not_track = envs.VLLM_DO_NOT_TRACK
no_usage_stats = envs.VLLM_NO_USAGE_STATS
do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH)
_USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats
......@@ -167,7 +169,7 @@ class UsageMessage:
# Metadata
self.log_time = _get_current_timestamp_ns()
self.source = os.environ.get("VLLM_USAGE_SOURCE", "production")
self.source = envs.VLLM_USAGE_SOURCE
data = vars(self)
if extra_kvs:
......
import asyncio
import datetime
import enum
import gc
import glob
import os
import socket
import subprocess
import tempfile
import threading
import uuid
import warnings
from collections import defaultdict
......@@ -18,7 +21,8 @@ import psutil
import torch
from packaging.version import Version, parse
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T")
logger = init_logger(__name__)
......@@ -171,7 +175,7 @@ def get_vllm_instance_id():
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")
return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
@lru_cache(maxsize=None)
......@@ -222,18 +226,25 @@ def merge_async_iterators(
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
if isinstance(item, Exception):
raise item
yield item
try:
while not all(finished) or not queue.empty():
item = await queue.get()
if isinstance(item, Exception):
raise item
yield item
except (Exception, asyncio.CancelledError) as e:
for task in _tasks:
# NOTE: Pass the error msg in cancel()
# when only Python 3.9+ is supported.
task.cancel()
raise e
await asyncio.gather(*_tasks)
return consumer()
def get_ip() -> str:
host_ip = os.environ.get("HOST_IP")
host_ip = envs.VLLM_HOST_IP
if host_ip:
return host_ip
......@@ -259,7 +270,8 @@ def get_ip() -> str:
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable HOST_IP.",
"The value can be set by the environment variable"
" VLLM_HOST_IP or HOST_IP.",
stacklevel=2)
return "0.0.0.0"
......@@ -286,8 +298,9 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'")
logger.warning(
"Overwriting environment variable %s "
"from '%s' to '%s'", k, os.environ[k], v)
os.environ[k] = v
......@@ -303,15 +316,16 @@ def cdiv(a: int, b: int) -> int:
@lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]:
cuda_home = os.environ.get('CUDA_HOME')
cuda_home = envs.CUDA_HOME
if not cuda_home:
cuda_home = '/usr/local/cuda'
if os.path.isfile(cuda_home + '/bin/nvcc'):
logger.info(f'CUDA_HOME is not found in the environment. '
f'Using {cuda_home} as CUDA_HOME.')
logger.info(
'CUDA_HOME is not found in the environment. '
'Using %s as CUDA_HOME.', cuda_home)
else:
logger.warning(
f'Not found nvcc in {cuda_home}. Skip cuda version check!')
logger.warning('Not found nvcc in %s. Skip cuda version check!',
cuda_home)
return None
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
universal_newlines=True)
......@@ -341,21 +355,9 @@ def _generate_random_fp8(
del tensor_tmp
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def get_kv_cache_torch_dtype(
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
if isinstance(model_dtype, str):
......@@ -374,6 +376,55 @@ def create_kv_caches_with_random(
torch_dtype = cache_dtype
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
return torch_dtype
def create_kv_caches_with_random_flash(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
assert cache_dtype != "fp8"
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5
key_caches, value_caches = [], []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype,
device=device)
key_value_cache.uniform_(-scale, scale)
key_caches.append(key_value_cache[:, 0])
value_caches.append(key_value_cache[:, 1])
return key_caches, value_caches
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
......@@ -569,7 +620,7 @@ def find_library(lib_name: str) -> str:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
# `LD_LIBRARY_PATH` searches the library in the user-defined paths
env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
env_ld_library_path = envs.LD_LIBRARY_PATH
if not locs and env_ld_library_path:
locs = [
os.path.join(dir, lib_name)
......@@ -582,22 +633,23 @@ def find_library(lib_name: str) -> str:
def find_nccl_library():
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
so_file = envs.VLLM_NCCL_SO_PATH
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
# check if we have vllm-managed nccl
vllm_nccl_path = None
if torch.version.cuda is not None:
cuda_major = torch.version.cuda.split(".")[0]
path = os.path.expanduser(
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
files = glob.glob(path)
vllm_nccl_path = files[0] if files else None
# manually load the nccl library
if so_file:
logger.info(
f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
)
"Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
so_file)
else:
if torch.version.cuda is not None:
so_file = vllm_nccl_path or find_library("libnccl.so.2")
......@@ -605,5 +657,21 @@ def find_nccl_library():
so_file = find_library("librccl.so.1")
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
logger.info("Found nccl from library %s", so_file)
return so_file
def enable_trace_function_call_for_thread() -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
if envs.VLLM_TRACE_FUNCTION:
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
......@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad, maybe_expand_dim
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
logger = init_logger(__name__)
......@@ -38,6 +37,8 @@ class CPUModelRunner:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config
......@@ -79,7 +80,7 @@ class CPUModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list:
......@@ -91,15 +92,15 @@ class CPUModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
prompt_len = len(prompt_tokens)
seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) # Prompt token num
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))
input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
......@@ -108,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window)
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prompt_len):
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
......@@ -150,19 +151,19 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
prompt_lens=prompt_lens,
num_prefills=len(prompt_lens),
seq_lens=seq_lens,
seq_lens_tensor=None,
max_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
prefill_metadata=None,
decode_metadata=None,
max_context_len=None,
context_lens=None,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input)
def _prepare_decode(
......@@ -173,7 +174,7 @@ class CPUModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
......@@ -191,9 +192,9 @@ class CPUModelRunner:
position = seq_len - 1
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
......@@ -207,7 +208,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_context_len = max(context_lens)
max_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
......@@ -218,9 +219,9 @@ class CPUModelRunner:
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
......@@ -235,14 +236,14 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
max_context_len=max_context_len,
num_prefills=0,
prefill_metadata=None,
decode_metadata=None,
context_lens=context_lens,
block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype,
)
......@@ -252,99 +253,6 @@ class CPUModelRunner:
attn_metadata,
)
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
subquery_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long)
categorized_sample_indices = {
t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......@@ -357,15 +265,22 @@ class CPUModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens,
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
......@@ -385,11 +300,10 @@ class CPUModelRunner:
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
seq_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
)
return (input_tokens, input_positions, attn_metadata,
......@@ -421,7 +335,7 @@ class CPUModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
if not self.is_driver_worker:
return None
# Sample the next token.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment