Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
...@@ -28,11 +28,12 @@ from transformers import PretrainedConfig ...@@ -28,11 +28,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module): ...@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module): ...@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module):
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2, config.hidden_size, [config.intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size, self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False)
...@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module): ...@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module): ...@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_key_value_heads, self.total_num_key_value_heads,
self.qkv_bias, self.qkv_bias,
linear_method=linear_method) quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.rotary_ndims, rotary_dim=self.rotary_ndims,
...@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module): ...@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = StablelmAttention(config) self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method) self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
...@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module): ...@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method) StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
...@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module): ...@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = StableLMEpochModel(config, linear_method) self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module): ...@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module): ...@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=self.use_bias, bias=self.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=self.use_bias, bias=self.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -121,21 +122,20 @@ class Starcoder2MLP(nn.Module): ...@@ -121,21 +122,20 @@ class Starcoder2MLP(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.c_fc = ColumnParallelLinear( self.c_fc = ColumnParallelLinear(
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
bias=config.use_bias, bias=config.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=config.use_bias, bias=config.use_bias,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size) config.intermediate_size)
...@@ -150,12 +150,11 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -150,12 +150,11 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
linear_method=linear_method) self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon) eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
...@@ -192,7 +191,7 @@ class Starcoder2Model(nn.Module): ...@@ -192,7 +191,7 @@ class Starcoder2Model(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -202,7 +201,7 @@ class Starcoder2Model(nn.Module): ...@@ -202,7 +201,7 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method) Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
...@@ -227,10 +226,10 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -227,10 +226,10 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method) self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings: if config.tie_word_embeddings:
......
...@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig ...@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -52,17 +53,17 @@ class XverseMLP(nn.Module): ...@@ -52,17 +53,17 @@ class XverseMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -85,7 +86,7 @@ class XverseAttention(nn.Module): ...@@ -85,7 +86,7 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
...@@ -112,13 +113,13 @@ class XverseAttention(nn.Module): ...@@ -112,13 +113,13 @@ class XverseAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=bias, bias=bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
bias=getattr(config, "bias", False), bias=getattr(config, "bias", False),
sliding_window=sliding_window, sliding_window=sliding_window,
) )
...@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -220,7 +221,7 @@ class XverseModel(nn.Module): ...@@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -236,7 +237,7 @@ class XverseModel(nn.Module): ...@@ -236,7 +237,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
XverseDecoderLayer(config, linear_method) XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module): ...@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config=None, lora_config=None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = XverseModel(config, linear_method) self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -6,57 +6,284 @@ import torch ...@@ -6,57 +6,284 @@ import torch
from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim)
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558 _SEED_0_REPLACEMENT = 3403598558
@dataclass
class SequenceGroupToSample:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
seq_ids: List[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
seq_data: Dict[int, SequenceData]
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
seq_len: Optional[int]
# The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of query_len <= seq_len if chunked
# prefill is enabled.
query_len: Optional[int]
# A random number generator for sampling.
generator: Optional[torch.Generator]
# True if the sequence group is in prefill stage. False if it is in a
# decode stage.
is_prompt: bool
# Query token indices from logits. to compute prompt logprob. Empty if
# prompt logprob is not required.
prompt_logprob_indices: List[int]
# Sample token indices from logits. Empty if sampling is not required.
sample_indices: List[int]
@property
def do_sample(self):
return len(self.sample_indices) > 0
def __post_init__(self):
if len(self.prompt_logprob_indices) > 0:
assert self.sampling_params.prompt_logprobs is not None
if self.is_prompt:
assert self.seq_len is not None
assert self.query_len is not None
class SamplingMetadata: class SamplingMetadata:
"""Metadata for input sequences. Used in sampler. """Metadata for input sequences. Used in sampler.
The usage is as follow;
```
hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)
def sample(logits):
# Use categorized_sample_indices for sampling....
```
Args: Args:
seq_groups: List of (seq_ids, sampling_params). seq_groups: List of batched sequence groups.
seq_data: Seq_id -> SequenceData. selected_token_indices: (num_query_tokens_to_logprob). Indices to find
prompt_lens: Lengths of prompts. logits from the initial model output hidden states.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample. categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling Each token indices is 2D tensor of (num_indices, num_indices) where
perform_sampling: Whether to perform sampling. This option is used to the first item means the sample index within the returned logit
make the sampling only happens in the driver worker, and disable (before pruning padding), and the second item means the sample
sampling in other worker processes. index after pruning using selected_token_indices.
For example, if the returned logit is [1, 2, 3], and we select
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
""" """
def __init__( def __init__(
self, self,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], seq_groups: List[SequenceGroupToSample],
seq_data: Optional[Dict[int, SequenceData]],
prompt_lens: Optional[List[int]],
selected_token_indices: torch.Tensor, selected_token_indices: torch.Tensor,
categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], categorized_sample_indices: Dict[SamplingType, torch.Tensor],
generators: Optional[List[torch.Generator]] = None, num_prompts: int,
perform_sampling: bool = True,
) -> None: ) -> None:
self.seq_groups = seq_groups self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices self.categorized_sample_indices = categorized_sample_indices
self.generators = generators self.num_prompts = num_prompts
self.perform_sampling = perform_sampling
self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 @staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
pin_memory: bool,
) -> "SamplingMetadata":
(
seq_groups,
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
num_prompts=num_prompts,
)
return sampling_metadata
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
"SamplingMetadata(" "SamplingMetadata("
f"seq_groups={self.seq_groups}, " f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, " f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices}), " f"categorized_sample_indices={self.categorized_sample_indices}), ")
f"perform_sampling={self.perform_sampling})")
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.
Args:
seq_group_metadata_list: A list of sequence group to batch.
seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
Returns:
seq_groups: A list of sequence group to sample.
selected_token_indices: See the definition from `SamplingMetadata`.
categorized_sample_indices: See the definition from `SamplingMetadata`.
num_prompts: Total number of prompts from `seq_group_metadata_list`.
"""
# Batched sequence groups for the current model forward stsep.
seq_groups: List[SequenceGroupToSample] = []
# A list of token indices to sample/compute logprob. It is used to
# prune the outcome logits from the model for the performance.
selected_token_indices: List[int] = []
# Used for selected_token_indices.
model_output_idx = 0
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups.
num_prompts = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=device).manual_seed(sampling_params.seed)
num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
assert query_lens is not None and seq_lens is not None
query_len, seq_len = query_lens[i], seq_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
"""
if sampling_params.prompt_logprobs:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + prompt_logprob_len))
model_output_idx += prompt_logprob_len
if do_sample:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + sample_len))
model_output_idx += sample_len
# We now find indices for logprob computation and sampling.
"""
This block computes categorized_sample_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
def sample(logits):
# Use categorized_sample_indices for sampling.
# prompt_logprob_indices to find prompt logprob indices.
# sample_indices to find sample indices.
"""
if sampling_params.prompt_logprobs is not None:
prompt_logprob_indices.extend(
range(logit_idx, logit_idx + prompt_logprob_len))
logit_idx += prompt_logprob_len
if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend(
list(
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
logit_idx += sample_len
sample_idx += sample_len
if sampling_params.seed is not None:
generator = seq_group_metadata.state.generator
seq_groups.append(
SequenceGroupToSample(
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
seq_len=seq_len,
query_len=query_len,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)))
return (seq_groups, selected_token_indices, categorized_sample_indices,
num_prompts)
@dataclass @dataclass
...@@ -112,11 +339,10 @@ class SamplingTensors: ...@@ -112,11 +339,10 @@ class SamplingTensors:
seeds_to_generate = (extra_seeds_to_generate + seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size)) get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None for seq_group in sampling_metadata.seq_groups:
for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids = seq_group.seq_ids
seq_ids, sampling_params = seq_group sampling_params = seq_group.sampling_params
temperature = sampling_params.temperature temperature = sampling_params.temperature
p = sampling_params.presence_penalty p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty f = sampling_params.frequency_penalty
...@@ -145,45 +371,46 @@ class SamplingTensors: ...@@ -145,45 +371,46 @@ class SamplingTensors:
or abs(r - 1.0) >= _SAMPLING_EPS): or abs(r - 1.0) >= _SAMPLING_EPS):
do_penalties = True do_penalties = True
if (i < sampling_metadata.num_prompts is_prompt = seq_group.is_prompt
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get # For tokens in the prompt that we only need to get
# their logprobs # their logprobs
assert sampling_metadata.prompt_lens is not None query_len = seq_group.query_len
prompt_len = sampling_metadata.prompt_lens[i] assert query_len is not None
temperatures += [temperature] * (prompt_len - 1) prefill_len = len(seq_group.prompt_logprob_indices)
top_ps += [top_p] * (prompt_len - 1) temperatures += [temperature] * prefill_len
top_ks += [top_k] * (prompt_len - 1) top_ps += [top_p] * prefill_len
min_ps += [min_p] * (prompt_len - 1) top_ks += [top_k] * prefill_len
presence_penalties += [0] * (prompt_len - 1) min_ps += [min_p] * prefill_len
frequency_penalties += [0] * (prompt_len - 1) presence_penalties += [0] * prefill_len
repetition_penalties += [1] * (prompt_len - 1) frequency_penalties += [0] * prefill_len
prompt_tokens.extend([] for _ in range(prompt_len - 1)) repetition_penalties += [1] * prefill_len
output_tokens.extend([] for _ in range(prompt_len - 1)) prompt_tokens.extend([] for _ in range(prefill_len))
for seq_id in seq_ids: output_tokens.extend([] for _ in range(prefill_len))
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids) if seq_group.do_sample:
output_tokens.append(seq_data.output_token_ids) sample_lens = len(seq_group.sample_indices)
temperatures += [temperature] * len(seq_ids) assert sample_lens == len(seq_ids)
top_ps += [top_p] * len(seq_ids) for seq_id in seq_ids:
top_ks += [top_k] * len(seq_ids) seq_data = seq_group.seq_data[seq_id]
min_ps += [min_p] * len(seq_ids) prompt_tokens.append(seq_data.prompt_token_ids)
presence_penalties += [p] * len(seq_ids) output_tokens.append(seq_data.output_token_ids)
frequency_penalties += [f] * len(seq_ids) temperatures += [temperature] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids) top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
is_prompt = i < sampling_metadata.num_prompts min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
if is_prompt: if is_prompt:
prompt_best_of.append(sampling_params.best_of) prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None query_len = seq_group.query_len
prompt_len = sampling_metadata.prompt_lens[i] assert query_len is not None
if sampling_params.prompt_logprobs is not None:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx += prompt_len - 1
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or () extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds( seq_seeds = cls._get_sequence_seeds(
seed, seed,
...@@ -193,8 +420,7 @@ class SamplingTensors: ...@@ -193,8 +420,7 @@ class SamplingTensors:
seeds_to_generate=seeds_to_generate, seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy) is_greedy=is_greedy)
sampling_seeds.append(seq_seeds) sampling_seeds.append(seq_seeds)
sample_indices.append(sample_indices_start_idx) sample_indices.extend(seq_group.sample_indices)
sample_indices_start_idx += 1
sampling_tensors = SamplingTensors.from_lists( sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties, temperatures, top_ps, top_ks, min_ps, presence_penalties,
...@@ -217,12 +443,14 @@ class SamplingTensors: ...@@ -217,12 +443,14 @@ class SamplingTensors:
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens) prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [ prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens)) tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens for tokens in prompt_tokens
] ]
output_max_len = max(len(tokens) for tokens in output_tokens) output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [ output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens)) tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens for tokens in output_tokens
......
...@@ -139,7 +139,10 @@ class SamplingParams: ...@@ -139,7 +139,10 @@ class SamplingParams:
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
self.min_p = min_p self.min_p = min_p
self.seed = seed if seed == -1:
self.seed = None
else:
self.seed = seed
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.length_penalty = length_penalty self.length_penalty = length_penalty
self.early_stopping = early_stopping self.early_stopping = early_stopping
...@@ -185,8 +188,8 @@ class SamplingParams: ...@@ -185,8 +188,8 @@ class SamplingParams:
self.top_k = -1 self.top_k = -1
self.min_p = 0.0 self.min_p = 0.0
self._verify_greedy_sampling() self._verify_greedy_sampling()
# injected by the engine # eos_token_id is added to this by the engine
self.eos_token_id = None self.all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.n < 1: if self.n < 1:
...@@ -275,7 +278,8 @@ class SamplingParams: ...@@ -275,7 +278,8 @@ class SamplingParams:
self, generation_config: Dict[str, Any]) -> None: self, generation_config: Dict[str, Any]) -> None:
"""Update if there are non-default values from generation_config""" """Update if there are non-default values from generation_config"""
# Update eos_token_id for generation # Update eos_token_id for generation
if eos_ids := generation_config.get("eos_token_id"): if (not self.ignore_eos) and (eos_ids :=
generation_config.get("eos_token_id")):
# it can be either int or list of int # it can be either int or list of int
if isinstance(eos_ids, int): if isinstance(eos_ids, int):
eos_ids = [eos_ids] eos_ids = [eos_ids]
......
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union from typing import TYPE_CHECKING, Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
...@@ -28,7 +28,10 @@ class Logprob: ...@@ -28,7 +28,10 @@ class Logprob:
decoded_token: Optional[str] = None decoded_token: Optional[str] = None
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]] PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]] SampleLogprobs = List[Dict[int, Logprob]]
...@@ -215,7 +218,7 @@ class Sequence: ...@@ -215,7 +218,7 @@ class Sequence:
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids) self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
...@@ -439,15 +442,27 @@ class SequenceGroup: ...@@ -439,15 +442,27 @@ class SequenceGroup:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
def get_last_latency(self, now: float) -> float: def get_last_latency(self, now: float) -> Optional[float]:
"""Gets last token latency for Request level timings.""" """Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
raise ValueError(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase.")
# Otherwise return token latency.
latency = now - self.metrics.last_token_time latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now self.metrics.last_token_time = now
return latency return latency
def maybe_set_first_token_time(self, time: float) -> None: def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings.""" """Sets the first token time for Request level timings."""
if self.metrics.first_token_time is None: # Note: in a case where a sequence_group is swapped and
# recomputed, the time between iterations is counted
# in TPOT, rather than recalculating TTFT (since from the )
# POV of the user, there is simply a long generation delay.
if (self.metrics.first_token_time is None
and self.get_seqs()[0].get_output_len() == 1):
self.metrics.first_token_time = time self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None: def maybe_set_first_scheduled_time(self, time: float) -> None:
...@@ -559,10 +574,15 @@ class SequenceGroupMetadata: ...@@ -559,10 +574,15 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block block_tables: The block tables. (Seq id -> list of physical block
numbers) numbers)
do_sample: True if sampling is required. Sampling is not required when
e.g., prefill is chunked, and the current iteration only computes
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence). token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required. None if chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request. lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data. multi_modal_data: Multi modal data.
""" """
...@@ -573,6 +593,7 @@ class SequenceGroupMetadata: ...@@ -573,6 +593,7 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData], seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams, sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], block_tables: Dict[int, List[int]],
do_sample: bool = True,
token_chunk_size: Optional[int] = None, token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None, computed_block_nums: Optional[List[int]] = None,
...@@ -589,6 +610,7 @@ class SequenceGroupMetadata: ...@@ -589,6 +610,7 @@ class SequenceGroupMetadata:
self.multi_modal_data = multi_modal_data self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
if self._token_chunk_size is None: if self._token_chunk_size is None:
if is_prompt: if is_prompt:
...@@ -650,6 +672,7 @@ class SequenceGroupOutput: ...@@ -650,6 +672,7 @@ class SequenceGroupOutput:
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
) -> None: ) -> None:
self.samples = samples self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -677,6 +700,9 @@ class SamplerOutput: ...@@ -677,6 +700,9 @@ class SamplerOutput:
# On-device tensor containing probabilities of each token. # On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None sampled_token_probs: Optional["torch.Tensor"] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids. # On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None sampled_token_ids: Optional["torch.Tensor"] = None
...@@ -708,3 +734,33 @@ class SamplerOutput: ...@@ -708,3 +734,33 @@ class SamplerOutput:
f"sampled_token_probs={sampled_token_probs_repr}, " f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
@dataclass
class ExecuteModelRequest:
"""The model execution request."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
running_queue_size: int = 0
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(),
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
)
from itertools import chain, count from itertools import chain, count
from typing import Dict, Iterator, List, Optional, Tuple from typing import Iterator, List, Tuple
import torch import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
...@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals") @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals( def score_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
) -> SpeculativeScores: ) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model. """Score the proposed tokens via the scorer model.
...@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence. no speculation is produced for that sequence.
Args: Args:
seq_group_metadata_list: The input sequence group metadata. execute_model_req: The execution request.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
proposals: The speculative proposals to score. proposals: The speculative proposals to score.
Returns: Returns:
SpeculativeScores: The scores of each speculative token, along with SpeculativeScores: The scores of each speculative token, along with
...@@ -80,33 +73,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -80,33 +73,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(spec_indices, non_spec_indices, target_seq_group_metadata_list, (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch( num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips, proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list, proposal_lens_list=proposal_lens_list,
) )
target_sampler_output = self._scorer_worker.execute_model( target_sampler_output = self._scorer_worker.execute_model(
seq_group_metadata_list=target_seq_group_metadata_list, execute_model_req=execute_model_req.clone(
blocks_to_swap_in=blocks_to_swap_in, seq_group_metadata_list=target_seq_group_metadata_list, ))
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert len(target_sampler_output) == 1, "expected single-step output" assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0] target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch( all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(seq_group_metadata_list), contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
num_scoring_tokens=num_scoring_tokens, num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices, non_spec_indices=non_spec_indices,
spec_indices=spec_indices, spec_indices=spec_indices,
k=k, k=execute_model_req.num_lookahead_slots,
) )
return SpeculativeScores( return SpeculativeScores(
probs=all_probs, probs=all_probs,
token_ids=all_tokens, token_ids=all_tokens,
logprobs=spec_logprobs,
) )
def _expand_batch( def _expand_batch(
...@@ -148,12 +139,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -148,12 +139,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return (spec_indices, non_spec_indices, target_seq_group_metadata_list, return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) num_scoring_tokens)
def _contract_batch(self, contracted_bs: int, def _contract_batch(
target_sampler_output: List[SamplerOutput], self, contracted_bs: int,
proposals: SpeculativeProposals, target_sampler_output: List[SamplerOutput],
num_scoring_tokens: int, non_spec_indices: List[int], proposals: SpeculativeProposals, num_scoring_tokens: int,
spec_indices: List[int], non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor]: k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Contract the expanded batch back into its original size. """Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original This maps the scores of speculative tokens back to their original
sequences. sequences.
...@@ -161,8 +152,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -161,8 +152,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to. target_sampler_output will be contracted to.
""" """
(target_token_ids, target_probs, non_spec_target_token_ids, (target_token_ids, target_probs, target_logprobs,
non_spec_target_probs) = self._split_scoring_output( non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens) target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token # Map distinct sequences used to score each token
...@@ -179,6 +171,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -179,6 +171,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs, k + 1) spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size) self._vocab_size)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size)
all_tokens = torch.full(size=(contracted_bs, k + 1), all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1, fill_value=-1,
...@@ -189,16 +183,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -189,16 +183,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self._vocab_size, self._vocab_size,
device=self._device, device=self._device,
dtype=torch.float32) dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
if non_spec_indices: if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
if spec_indices: if spec_indices:
all_tokens[spec_indices] = target_token_ids all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
return all_tokens, all_probs return all_tokens, all_probs, all_logprobs
def _create_scoring_model_input( def _create_scoring_model_input(
self, self,
...@@ -308,7 +312,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -308,7 +312,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _split_scoring_output( def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""Split the target model output into speculative and non-speculative """Split the target model output into speculative and non-speculative
output. output.
""" """
...@@ -328,21 +333,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -328,21 +333,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
) = sampler_output.sampled_token_probs.split(split_sizes) ) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens (spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes) ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
# Convert scores to tensors. # Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.sampled_token_ids = spec_sampled_tokens
target_token_ids, target_probs = sampler_output_to_torch( sampler_output.logprobs = spec_logprobs
[sampler_output]) (target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
# Convert non-speculative output tokens to tensors. # Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = ( sampler_output.logprobs = non_spec_logprobs
sampler_output_to_torch([sampler_output])) (non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
return (target_token_ids, target_probs, non_spec_target_token_ids, True)
non_spec_target_probs)
return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
def _create_target_seq_id_iterator( def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional
import torch import torch
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest
@dataclass @dataclass
...@@ -38,6 +37,11 @@ class SpeculativeScores: ...@@ -38,6 +37,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model. # Probabilities of the speculative tokens according to the scoring model.
probs: torch.Tensor probs: torch.Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs: torch.Tensor
# Token ids sampled from the scoring model. Used for speculative bonus # Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding. # tokens and also non-speculative normal decoding.
token_ids: torch.Tensor token_ids: torch.Tensor
...@@ -53,11 +57,7 @@ class SpeculativeProposer(ABC): ...@@ -53,11 +57,7 @@ class SpeculativeProposer(ABC):
@abstractmethod @abstractmethod
def get_proposals( def get_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
raise NotImplementedError raise NotImplementedError
...@@ -67,11 +67,7 @@ class SpeculativeScorer(ABC): ...@@ -67,11 +67,7 @@ class SpeculativeScorer(ABC):
@abstractmethod @abstractmethod
def score_proposals( def score_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
) -> SpeculativeScores: ) -> SpeculativeScores:
raise NotImplementedError raise NotImplementedError
import copy import copy
from typing import Dict, List, Optional, Tuple from typing import List, Tuple
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
from vllm.spec_decode.interfaces import (SpeculativeProposals, SequenceGroupMetadata)
SpeculativeProposer) from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.util import sampler_output_to_torch from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -26,50 +26,53 @@ class MultiStepWorker(Worker): ...@@ -26,50 +26,53 @@ class MultiStepWorker(Worker):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Lazy initialization list. # Lazy initialization list.
self._proposer: DraftModelTop1Proposer self._proposer: Top1Proposer
def init_device(self): def init_device(self):
super().init_device() super().init_device()
self._proposer = DraftModelTop1Proposer( self._proposer = Top1Proposer(
self, self,
self.device, self.device,
self.max_model_len,
self.vocab_size, self.vocab_size,
max_proposal_len=self.max_model_len,
) )
def set_include_gpu_probs_tensor(self):
# Need include_gpu_probs_tensor for multi_step_worker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
@torch.inference_mode() @torch.inference_mode()
def execute_model_multi_step( def sampler_output(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int], sample_len: int,
blocks_to_swap_out: Dict[int, int], ) -> Tuple[List[SamplerOutput], bool]:
blocks_to_copy: Dict[int, List[int]], """Run the model forward pass sample_len times. Returns the list of
num_steps: int, sampler output, one per model forward pass, along with indicator of
) -> List[SamplerOutput]: whether torch tensor in sampler output need to be transposed in latter
"""Run the model forward pass num_steps times. Returns the list of sampler_output_to_torch logic.
sampler output, one per model forward pass.
For multi step worker, this indicator shall be True.
""" """
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, self._raise_if_unsupported(execute_model_req)
blocks_to_swap_out, blocks_to_copy)
# Shallow copy input data so modifications (such as appending tokens) # Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects. # do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs( copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list) execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence. # Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps) self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
# Run model num_steps times. # Run model sample_len times.
model_outputs = [] model_outputs = []
for _ in range(num_steps): for _ in range(sample_len):
model_output = super().execute_model( model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list, execute_model_req=copied_execute_model_req)
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert (len(model_output) == 1 assert (len(model_output) == 1
), "composing multistep workers not supported" ), "composing multistep workers not supported"
model_output = model_output[0] model_output = model_output[0]
...@@ -78,27 +81,17 @@ class MultiStepWorker(Worker): ...@@ -78,27 +81,17 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list) copied_seq_group_metadata_list)
model_outputs.append(model_output) model_outputs.append(model_output)
return model_outputs return model_outputs, True
def get_spec_proposals( def get_spec_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of """Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len. speculative tokens per sequence is determined by max_proposal_len.
""" """
return self._proposer.get_proposals( return self._proposer.get_proposals(execute_model_req)
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
def _append_new_tokens( def _append_new_tokens(
self, model_output: SamplerOutput, self, model_output: SamplerOutput,
...@@ -189,188 +182,22 @@ class MultiStepWorker(Worker): ...@@ -189,188 +182,22 @@ class MultiStepWorker(Worker):
def _raise_if_unsupported( def _raise_if_unsupported(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None: ) -> None:
"""MultiStepWorker does not yet implement support for cache swap """MultiStepWorker does not yet implement support for cache swap
operations or beam search. operations or beam search.
""" """
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError( raise NotImplementedError(
"MultiStepWorker does not support cache operations") "MultiStepWorker does not support cache operations")
if any( if any(
len(seq_group_metadata.seq_data.keys()) != 1 len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list): for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError( raise NotImplementedError(
"MultiStepWorker does not support beam search.") "MultiStepWorker does not support beam search.")
class DraftModelTop1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
draft_worker: MultiStepWorker,
device: str,
max_model_len: int,
vocab_size: int,
):
self._draft_worker = draft_worker
self._device = device
self._max_model_len = max_model_len
self._vocab_size = vocab_size
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
(proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices) = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
maybe_sampler_output = self._draft_worker.execute_model_multi_step(
seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_steps=max_proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
max_proposal_len=max_proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
max_proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length.
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
if seq_len + max_proposal_len < self._max_model_len:
proposal_lens.append(max_proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return (proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices)
def _merge_outputs(
self,
batch_size: int,
max_proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(size=(
batch_size,
max_proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full(size=(batch_size,
*proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor
from typing import List, Optional, Tuple
import torch
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(LoraNotSupportedWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scenerios
which don't rely on LLM model to give proposals.
"""
def __init__(self, *args, **kwargs):
# Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["model_config"].get_vocab_size()
# Lazy initialization list.
self._proposer: Top1Proposer
def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
ngram_prompt_lookup_max: int):
# Search valid candidate window between
# ngram_prompt_lookup_min/ngram_prompt_lookup_max
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Current only support Top1Proposer
self._proposer = Top1Proposer(
self,
device=self.device,
vocab_size=self.vocab_size,
)
def set_include_gpu_probs_tensor(self):
# NGram don't need gpu sampler
pass
def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def determine_num_available_blocks(self) -> None:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""As there is no cache need to handle, just pass this function"""
pass
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes."""
return 0
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self._raise_if_unsupported(execute_model_req)
arr = []
has_spec_out = False
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))
input_ids = torch.as_tensor(seq_data.get_token_ids(),
dtype=torch.long,
device=self.device)
input_length = seq_data.get_len()
for ngram_size in range(
min(self.ngram_prompt_lookup_max, input_length - 1),
self.ngram_prompt_lookup_min,
-1,
):
ngram_tensor = input_ids[-1 * ngram_size:]
windows = input_ids.unfold(dimension=0,
size=ngram_size,
step=1)
matches = (windows == ngram_tensor).all(dim=1)
match_indices = matches.nonzero(as_tuple=True)[0]
if match_indices.size()[0] > 1:
has_spec_out = True
res = seq_data.get_token_ids()
res = res[match_indices[0] + ngram_size:match_indices[0] +
ngram_size + sample_len]
res_len = len(res)
# pad 0 towards output as sample_len tokens required
res += [0] * (sample_len - res_len)
break
else:
# if no candidate found, fill with 0
res = [0] * sample_len
arr.append(res)
if not has_spec_out:
return None, False
outputs = []
token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
indices = token_ids.unsqueeze(2)
token_probs = torch.zeros(
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros(
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
for i in range(len(execute_model_req.seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_probs[i],
logprobs=token_logprobs,
sampled_token_ids=token_ids[i],
))
return outputs, False
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(execute_model_req)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"NGramWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"NGramWorker does not support beam search.")
from functools import cached_property from functools import cached_property
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupOutput, SequenceOutput) SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
...@@ -48,8 +51,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -48,8 +51,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
""" """
@classmethod @classmethod
def from_workers(cls, proposer_worker: MultiStepWorker, def create_worker(
scorer_worker: WorkerBase) -> "SpecDecodeWorker": cls,
scorer_worker: WorkerBase,
draft_worker_kwargs,
) -> "SpecDecodeWorker":
if "ngram_prompt_lookup_max" in draft_worker_kwargs:
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
else:
ngram_prompt_lookup_max = 0
if ngram_prompt_lookup_max > 0:
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
return SpecDecodeWorker( return SpecDecodeWorker(
proposer_worker, proposer_worker,
scorer_worker, scorer_worker,
...@@ -59,7 +81,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -59,7 +81,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def __init__( def __init__(
self, self,
proposer_worker: MultiStepWorker, proposer_worker: WorkerBase,
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler, rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
...@@ -134,8 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -134,8 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
""" """
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True ) = True
(self.proposer_worker.model_runner.model.sampler. self.proposer_worker.set_include_gpu_probs_tensor()
include_gpu_probs_tensor) = True
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use. """Determine the number of cache blocks to use.
...@@ -169,68 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -169,68 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch. """Perform speculative decoding on the input batch.
""" """
assert seq_group_metadata_list is not None, ( assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding " "speculative decoding "
"requires non-None seq_group_metadata_list") "requires non-None seq_group_metadata_list")
logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
# If no spec tokens, call the proposer and scorer workers normally. # If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill. # Used for prefill.
if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: if execute_model_req.num_lookahead_slots == 0 or len(
return self._run_no_spec( execute_model_req.seq_group_metadata_list) == 0:
seq_group_metadata_list=seq_group_metadata_list, return self._run_no_spec(execute_model_req)
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, return self._run_speculative_decoding_step(execute_model_req)
blocks_to_copy=blocks_to_copy,
)
return self._run_speculative_decoding_step(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
k=num_lookahead_slots,
)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec( def _run_no_spec(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to the """Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the proposer and scorer model so that the KV cache is consistent between the
two. two.
""" """
logger.info("run proposer worker no spec") #logger.info("run proposer worker no spec")
self.proposer_worker.execute_model( self.proposer_worker.execute_model(execute_model_req)
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
logger.info("run target worker no spec") #logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model( sampler_output = self.scorer_worker.execute_model(execute_model_req)
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert len(sampler_output) == 1 assert len(sampler_output) == 1
sampler_output = sampler_output[0] sampler_output = sampler_output[0]
...@@ -238,17 +228,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -238,17 +228,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers. # overhead when the engine runs in a different process than the workers.
sampler_output.probs = None sampler_output.probs = None
sampler_output.sampled_tokens = None sampler_output.sampled_tokens = None
sampler_output.logprobs = None
return [sampler_output] return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step") @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step( def _run_speculative_decoding_step(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding. """Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each This invokes the proposer worker to get k speculative tokens for each
...@@ -258,32 +244,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -258,32 +244,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence. sequence.
""" """
logger.info("get spec proposals") #logger.info("get spec proposals")
# Generate proposals using draft worker. # Generate proposals using draft worker.
assert blocks_to_swap_in is not None proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None #logger.info("score proposals")
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
logger.info("score proposals")
proposal_scores = self.scorer.score_proposals( proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list, execute_model_req,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
k,
proposals, proposals,
) )
logger.info("verify proposals") #logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list, accepted_token_ids, target_logprobs = self._verify_tokens(
proposal_scores, proposals, k) execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
logger.info("create output list") #logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list, return self._create_output_sampler_list(
accepted_token_ids, k) execute_model_req.seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=execute_model_req.num_lookahead_slots)
@nvtx_range("spec_decode_worker._verify_tokens") @nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens( def _verify_tokens(
...@@ -292,9 +273,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -292,9 +273,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores, proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
max_proposal_len: int, max_proposal_len: int,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Determine which speculative tokens are accepted using the """Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models. probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
""" """
proposal_lens_list = proposals.proposal_lens.tolist() proposal_lens_list = proposals.proposal_lens.tolist()
...@@ -341,17 +325,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -341,17 +325,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids[:, 1:] = -1 non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat( accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids]) [accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group # Rearrange so that results are in the order of the original seq group
# metadata. # metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone() accepted_token_ids[original_indices] = accepted_token_ids.clone()
return accepted_token_ids return accepted_token_ids, logprobs
def _create_output_sampler_list( def _create_output_sampler_list(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int, k: int,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput. """Given the accepted token ids, create a list of SamplerOutput.
...@@ -359,30 +345,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -359,30 +345,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has The output is padded with -1 tokens such that each sequence has
the same number of outputs. the same number of outputs.
""" """
batch_size, num_steps = accepted_token_ids.shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids = get_all_seq_ids(seq_group_metadata_list) seq_ids = get_all_seq_ids(seq_group_metadata_list)
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# shape: [k+1, batch_size]
accepted_token_ids_by_step = accepted_token_ids.transpose(0, # Serialize all tensors to CPU Python lists.
1).tolist() accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list = [] sampler_output_list = []
for token_ids_by_step in accepted_token_ids_by_step: for step_index in range(num_steps):
if all(token_id == -1 for token_id in token_ids_by_step): if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break break
step_output_token_ids = [] step_output_token_ids = []
for token_id, seq_id in zip(token_ids_by_step, seq_ids): for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append( step_output_token_ids.append(
SequenceGroupOutput( create_sequence_group_output(
samples=[ token_id=accepted_token_ids_by_step[step_index]
SequenceOutput( [sequence_index],
parent_seq_id=seq_id, token_id_logprob_rank=accepted_token_id_ranks_by_step[
output_token=token_id, step_index][sequence_index],
# TODO Add verifier logprobs. token_id_logprob=accepted_token_id_logprobs_by_step[
logprobs={token_id: Logprob(0.0)}, step_index][sequence_index],
) seq_id=seq_ids[sequence_index],
], topk_token_ids=topk_indices_by_step[step_index]
prompt_logprobs=None, [sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
)) ))
sampler_output_list.append( sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids)) SamplerOutput(outputs=step_output_token_ids))
......
from typing import List, Optional, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker_base import WorkerBase
class Top1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
worker: WorkerBase,
device: str,
vocab_size: int,
max_proposal_len: Optional[int] = None,
):
self._worker = worker
self._device = device
self.max_proposal_len = max_proposal_len
self._vocab_size = vocab_size
def get_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
proposal_len = execute_model_req.num_lookahead_slots
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
# Split speculative- and non-speculative- sequences.
(
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is true, then maybe_sampler_output's
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
nonzero_execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=nonzero_proposal_len_seqs,
num_lookahead_slots=proposal_len,
)
maybe_sampler_output, transposed = self._worker.sampler_output(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
transposed = False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
proposal_len=proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
sampler_transposed=transposed,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length."""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exccess this
# quota for nonzero_proposal
if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len):
proposal_lens.append(proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return (
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
)
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(
size=(
batch_size,
proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
proposal_probs = torch.zeros(
batch_size,
proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device,
)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(
batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device,
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor
from contextlib import contextmanager from contextlib import contextmanager
from itertools import chain from itertools import chain
from typing import List, Tuple from typing import Dict, List, Tuple
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
SeqId = int SeqId = int
...@@ -21,6 +22,89 @@ def get_all_seq_ids( ...@@ -21,6 +22,89 @@ def get_all_seq_ids(
])) ]))
def get_all_num_logprobs(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)
return all_num_logprobs
def get_sampled_token_logprobs(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps, batch_size, vocab_size = logprob_tensor.shape
selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids, ]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >=
expanded_selected_logprobs).sum(-1)
return sampled_token_ids_ranks, selected_logprobs
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
) -> SequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[int]): The list of top-k token ids.
topk_logprobs (List[float]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
})
return SequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
# TODO add prompt logprobs support.
prompt_logprobs=None,
)
def split_batch_by_proposal_len( def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool proposal_lens: List[int], select_proposal_len_zero: bool
...@@ -49,10 +133,13 @@ def split_batch_by_proposal_len( ...@@ -49,10 +133,13 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch( def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput], sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Utility function which converts a list of SamplerOutput to tensors. """Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns: Returns:
sampled_token_ids: torch.Tensor sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)] shape: [batch_size, len(sampler_output_list)]
...@@ -68,7 +155,19 @@ def sampler_output_to_torch( ...@@ -68,7 +155,19 @@ def sampler_output_to_torch(
for sampler_output in sampler_output_list for sampler_output in sampler_output_list
], ],
dim=0, dim=0,
).transpose(0, 1) )
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
if sampler_transposed:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
# shape: [batch_size, num_sampler_output] # shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack( sampled_token_ids = torch.stack(
...@@ -77,9 +176,11 @@ def sampler_output_to_torch( ...@@ -77,9 +176,11 @@ def sampler_output_to_torch(
for sampler_output in sampler_output_list for sampler_output in sampler_output_list
], ],
dim=0, dim=0,
).transpose(0, 1) )
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)
return sampled_token_ids, sampled_token_probs return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
......
...@@ -72,9 +72,10 @@ class DbrxAttentionConfig(PretrainedConfig): ...@@ -72,9 +72,10 @@ class DbrxAttentionConfig(PretrainedConfig):
and config_dict["model_type"] != cls.model_type and config_dict["model_type"] != cls.model_type
): ):
logger.warning( logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " "You are using a model of type %s to instantiate a model of "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." "type %s. This is not supported for all configurations of "
) "models and can yield errors.",
config_dict["model_type"], cls.model_type)
return cls.from_dict(config_dict, **kwargs) return cls.from_dict(config_dict, **kwargs)
...@@ -151,9 +152,9 @@ class DbrxFFNConfig(PretrainedConfig): ...@@ -151,9 +152,9 @@ class DbrxFFNConfig(PretrainedConfig):
and config_dict["model_type"] != cls.model_type and config_dict["model_type"] != cls.model_type
): ):
logger.warning( logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " "You are using a model of type %s to instantiate a model of "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." "type %s. This is not supported for all "
) "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type)
return cls.from_dict(config_dict, **kwargs) return cls.from_dict(config_dict, **kwargs)
......
import os import os
from typing import Optional, Union from typing import Optional, Union
import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from vllm.config import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.transformers_utils.tokenizers import BaichuanTokenizer
...@@ -58,11 +59,12 @@ def get_tokenizer( ...@@ -58,11 +59,12 @@ def get_tokenizer(
*args, *args,
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
trust_remote_code: bool = False, trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None, revision: Optional[str] = None,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface/modelscope.""" """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if VLLM_USE_MODELSCOPE: if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub, # download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use. # lazy import so that modelscope is not required for normal use.
...@@ -74,9 +76,10 @@ def get_tokenizer( ...@@ -74,9 +76,10 @@ def get_tokenizer(
tokenizer_path = snapshot_download( tokenizer_path = snapshot_download(
model_id=tokenizer_name, model_id=tokenizer_name,
cache_dir=download_dir, cache_dir=download_dir,
revision=tokenizer_revision, revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
# Ignore weights - we only need the tokenizer. # Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
tokenizer_name = tokenizer_path tokenizer_name = tokenizer_path
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
...@@ -90,7 +93,7 @@ def get_tokenizer( ...@@ -90,7 +93,7 @@ def get_tokenizer(
tokenizer_name, tokenizer_name,
*args, *args,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, revision=revision,
**kwargs) **kwargs)
except ValueError as e: except ValueError as e:
# If the error pertains to the tokenizer class not existing or not # If the error pertains to the tokenizer class not existing or not
...@@ -114,7 +117,7 @@ def get_tokenizer( ...@@ -114,7 +117,7 @@ def get_tokenizer(
tokenizer_name, tokenizer_name,
*args, *args,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, revision=revision,
**kwargs) **kwargs)
else: else:
raise e raise e
...@@ -137,9 +140,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, ...@@ -137,9 +140,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
# No tokenizer was found in the LoRA folder, # No tokenizer was found in the LoRA folder,
# use base model tokenizer # use base model tokenizer
logger.warning( logger.warning(
f"No tokenizer found in {lora_request.lora_local_path}, " "No tokenizer found in %s, using base model tokenizer instead. "
"using base model tokenizer instead. " "(Exception: %s)", lora_request.lora_local_path, e)
f"(Exception: {str(e)})")
tokenizer = None tokenizer = None
return tokenizer return tokenizer
......
from typing import Optional from typing import Optional
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.engine.ray_utils import ray from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup) BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
......
...@@ -6,7 +6,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy ...@@ -6,7 +6,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.engine.ray_utils import ray from vllm.executor.ray_utils import ray
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup) BaseTokenizerGroup)
......
...@@ -15,20 +15,22 @@ import psutil ...@@ -15,20 +15,22 @@ import psutil
import requests import requests
import torch import torch
_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) import vllm.envs as envs
_config_home = envs.VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json") _USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json")
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, _USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home,
"vllm/do_not_track") "vllm/do_not_track")
_USAGE_STATS_ENABLED = None _USAGE_STATS_ENABLED = None
_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER", _USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER
"https://stats.vllm.ai")
def is_usage_stats_enabled(): def is_usage_stats_enabled():
"""Determine whether or not we can send usage stats to the server. """Determine whether or not we can send usage stats to the server.
The logic is as follows: The logic is as follows:
- By default, it should be enabled. - By default, it should be enabled.
- Two environment variables can disable it: - Three environment variables can disable it:
- VLLM_DO_NOT_TRACK=1
- DO_NOT_TRACK=1 - DO_NOT_TRACK=1
- VLLM_NO_USAGE_STATS=1 - VLLM_NO_USAGE_STATS=1
- A file in the home directory can disable it if it exists: - A file in the home directory can disable it if it exists:
...@@ -36,8 +38,8 @@ def is_usage_stats_enabled(): ...@@ -36,8 +38,8 @@ def is_usage_stats_enabled():
""" """
global _USAGE_STATS_ENABLED global _USAGE_STATS_ENABLED
if _USAGE_STATS_ENABLED is None: if _USAGE_STATS_ENABLED is None:
do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1" do_not_track = envs.VLLM_DO_NOT_TRACK
no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1" no_usage_stats = envs.VLLM_NO_USAGE_STATS
do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH)
_USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats
...@@ -167,7 +169,7 @@ class UsageMessage: ...@@ -167,7 +169,7 @@ class UsageMessage:
# Metadata # Metadata
self.log_time = _get_current_timestamp_ns() self.log_time = _get_current_timestamp_ns()
self.source = os.environ.get("VLLM_USAGE_SOURCE", "production") self.source = envs.VLLM_USAGE_SOURCE
data = vars(self) data = vars(self)
if extra_kvs: if extra_kvs:
......
import asyncio import asyncio
import datetime
import enum import enum
import gc import gc
import glob import glob
import os import os
import socket import socket
import subprocess import subprocess
import tempfile
import threading
import uuid import uuid
import warnings import warnings
from collections import defaultdict from collections import defaultdict
...@@ -18,7 +21,8 @@ import psutil ...@@ -18,7 +21,8 @@ import psutil
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
from vllm.logger import init_logger import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T") T = TypeVar("T")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -171,7 +175,7 @@ def get_vllm_instance_id(): ...@@ -171,7 +175,7 @@ def get_vllm_instance_id():
Instance id represents an instance of the VLLM. All processes in the same Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id. instance should have the same instance id.
""" """
return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}") return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
...@@ -222,18 +226,25 @@ def merge_async_iterators( ...@@ -222,18 +226,25 @@ def merge_async_iterators(
] ]
async def consumer(): async def consumer():
while not all(finished) or not queue.empty(): try:
item = await queue.get() while not all(finished) or not queue.empty():
if isinstance(item, Exception): item = await queue.get()
raise item if isinstance(item, Exception):
yield item raise item
yield item
except (Exception, asyncio.CancelledError) as e:
for task in _tasks:
# NOTE: Pass the error msg in cancel()
# when only Python 3.9+ is supported.
task.cancel()
raise e
await asyncio.gather(*_tasks) await asyncio.gather(*_tasks)
return consumer() return consumer()
def get_ip() -> str: def get_ip() -> str:
host_ip = os.environ.get("HOST_IP") host_ip = envs.VLLM_HOST_IP
if host_ip: if host_ip:
return host_ip return host_ip
...@@ -259,7 +270,8 @@ def get_ip() -> str: ...@@ -259,7 +270,8 @@ def get_ip() -> str:
warnings.warn( warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default." "Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable HOST_IP.", "The value can be set by the environment variable"
" VLLM_HOST_IP or HOST_IP.",
stacklevel=2) stacklevel=2)
return "0.0.0.0" return "0.0.0.0"
...@@ -286,8 +298,9 @@ def get_open_port() -> int: ...@@ -286,8 +298,9 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]): def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items(): for k, v in envs.items():
if k in os.environ and os.environ[k] != v: if k in os.environ and os.environ[k] != v:
logger.warning(f"Overwriting environment variable {k} " logger.warning(
f"from '{os.environ[k]}' to '{v}'") "Overwriting environment variable %s "
"from '%s' to '%s'", k, os.environ[k], v)
os.environ[k] = v os.environ[k] = v
...@@ -303,15 +316,16 @@ def cdiv(a: int, b: int) -> int: ...@@ -303,15 +316,16 @@ def cdiv(a: int, b: int) -> int:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]: def get_nvcc_cuda_version() -> Optional[Version]:
cuda_home = os.environ.get('CUDA_HOME') cuda_home = envs.CUDA_HOME
if not cuda_home: if not cuda_home:
cuda_home = '/usr/local/cuda' cuda_home = '/usr/local/cuda'
if os.path.isfile(cuda_home + '/bin/nvcc'): if os.path.isfile(cuda_home + '/bin/nvcc'):
logger.info(f'CUDA_HOME is not found in the environment. ' logger.info(
f'Using {cuda_home} as CUDA_HOME.') 'CUDA_HOME is not found in the environment. '
'Using %s as CUDA_HOME.', cuda_home)
else: else:
logger.warning( logger.warning('Not found nvcc in %s. Skip cuda version check!',
f'Not found nvcc in {cuda_home}. Skip cuda version check!') cuda_home)
return None return None
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
universal_newlines=True) universal_newlines=True)
...@@ -341,21 +355,9 @@ def _generate_random_fp8( ...@@ -341,21 +355,9 @@ def _generate_random_fp8(
del tensor_tmp del tensor_tmp
def create_kv_caches_with_random( def get_kv_cache_torch_dtype(
num_blocks: int, cache_dtype: Optional[Union[str, torch.dtype]],
block_size: int, model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
if isinstance(cache_dtype, str): if isinstance(cache_dtype, str):
if cache_dtype == "auto": if cache_dtype == "auto":
if isinstance(model_dtype, str): if isinstance(model_dtype, str):
...@@ -374,6 +376,55 @@ def create_kv_caches_with_random( ...@@ -374,6 +376,55 @@ def create_kv_caches_with_random(
torch_dtype = cache_dtype torch_dtype = cache_dtype
else: else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
return torch_dtype
def create_kv_caches_with_random_flash(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
assert cache_dtype != "fp8"
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
scale = head_size**-0.5
key_caches, value_caches = [], []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype,
device=device)
key_value_cache.uniform_(-scale, scale)
key_caches.append(key_value_cache[:, 0])
value_caches.append(key_value_cache[:, 1])
return key_caches, value_caches
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5 scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size() x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
...@@ -569,7 +620,7 @@ def find_library(lib_name: str) -> str: ...@@ -569,7 +620,7 @@ def find_library(lib_name: str) -> str:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
# `LD_LIBRARY_PATH` searches the library in the user-defined paths # `LD_LIBRARY_PATH` searches the library in the user-defined paths
env_ld_library_path = os.getenv("LD_LIBRARY_PATH") env_ld_library_path = envs.LD_LIBRARY_PATH
if not locs and env_ld_library_path: if not locs and env_ld_library_path:
locs = [ locs = [
os.path.join(dir, lib_name) os.path.join(dir, lib_name)
...@@ -582,22 +633,23 @@ def find_library(lib_name: str) -> str: ...@@ -582,22 +633,23 @@ def find_library(lib_name: str) -> str:
def find_nccl_library(): def find_nccl_library():
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") so_file = envs.VLLM_NCCL_SO_PATH
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
# check if we have vllm-managed nccl # check if we have vllm-managed nccl
vllm_nccl_path = None vllm_nccl_path = None
if torch.version.cuda is not None: if torch.version.cuda is not None:
cuda_major = torch.version.cuda.split(".")[0] cuda_major = torch.version.cuda.split(".")[0]
path = os.path.expanduser( path = os.path.expanduser(
f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
files = glob.glob(path) files = glob.glob(path)
vllm_nccl_path = files[0] if files else None vllm_nccl_path = files[0] if files else None
# manually load the nccl library # manually load the nccl library
if so_file: if so_file:
logger.info( logger.info(
f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}" "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
) so_file)
else: else:
if torch.version.cuda is not None: if torch.version.cuda is not None:
so_file = vllm_nccl_path or find_library("libnccl.so.2") so_file = vllm_nccl_path or find_library("libnccl.so.2")
...@@ -605,5 +657,21 @@ def find_nccl_library(): ...@@ -605,5 +657,21 @@ def find_nccl_library():
so_file = find_library("librccl.so.1") so_file = find_library("librccl.so.1")
else: else:
raise ValueError("NCCL only supports CUDA and ROCm backends.") raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}") logger.info("Found nccl from library %s", so_file)
return so_file return so_file
def enable_trace_function_call_for_thread() -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
if envs.VLLM_TRACE_FUNCTION:
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict ...@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad
from vllm.utils import make_tensor_with_pad, maybe_expand_dim
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -38,6 +37,8 @@ class CPUModelRunner: ...@@ -38,6 +37,8 @@ class CPUModelRunner:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config self.vision_language_config = vision_language_config
self.load_config = load_config self.load_config = load_config
...@@ -79,7 +80,7 @@ class CPUModelRunner: ...@@ -79,7 +80,7 @@ class CPUModelRunner:
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
prompt_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
...@@ -91,15 +92,15 @@ class CPUModelRunner: ...@@ -91,15 +92,15 @@ class CPUModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids() prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens() computed_len = seq_data.get_num_computed_tokens()
prompt_len = len(prompt_tokens) seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) # Prompt token num seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids input_tokens.extend(prompt_tokens) # Token ids
# Token position ids # Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len))) input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data: if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append( multi_modal_input_list.append(
...@@ -108,15 +109,15 @@ class CPUModelRunner: ...@@ -108,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping. # Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window). # where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and # For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot # block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0 start_idx = 0
if self.sliding_window is not None: if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window) start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prompt_len): for i in range(computed_len, seq_len):
if i < start_idx: if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
continue continue
...@@ -150,19 +151,19 @@ class CPUModelRunner: ...@@ -150,19 +151,19 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
prompt_lens=prompt_lens, seq_lens=seq_lens,
num_prefills=len(prompt_lens), seq_lens_tensor=None,
max_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens, num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0, num_decode_tokens=0,
prefill_metadata=None, prefill_metadata=None,
decode_metadata=None, decode_metadata=None,
max_context_len=None,
context_lens=None,
block_tables=torch.tensor([]), block_tables=torch.tensor([]),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, prompt_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input) multi_modal_input)
def _prepare_decode( def _prepare_decode(
...@@ -173,7 +174,7 @@ class CPUModelRunner: ...@@ -173,7 +174,7 @@ class CPUModelRunner:
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
context_lens: List[int] = [] seq_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
...@@ -191,9 +192,9 @@ class CPUModelRunner: ...@@ -191,9 +192,9 @@ class CPUModelRunner:
position = seq_len - 1 position = seq_len - 1
input_positions.append(position) input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min( seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window) seq_len, self.sliding_window)
context_lens.append(context_len) seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
...@@ -207,7 +208,7 @@ class CPUModelRunner: ...@@ -207,7 +208,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table) block_tables.append(block_table)
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens, input_tokens = torch.tensor(input_tokens,
dtype=torch.long, dtype=torch.long,
...@@ -218,9 +219,9 @@ class CPUModelRunner: ...@@ -218,9 +219,9 @@ class CPUModelRunner:
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
context_lens = torch.tensor(context_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
max_block_table_len = max( max_block_table_len = max(
len(block_table) for block_table in block_tables) len(block_table) for block_table in block_tables)
...@@ -235,14 +236,14 @@ class CPUModelRunner: ...@@ -235,14 +236,14 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=None, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=len(input_tokens), num_decode_tokens=len(input_tokens),
max_context_len=max_context_len,
num_prefills=0, num_prefills=0,
prefill_metadata=None, prefill_metadata=None,
decode_metadata=None, decode_metadata=None,
context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
...@@ -252,99 +253,6 @@ class CPUModelRunner: ...@@ -252,99 +253,6 @@ class CPUModelRunner:
attn_metadata, attn_metadata,
) )
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
subquery_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long)
categorized_sample_indices = {
t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -357,15 +265,22 @@ class CPUModelRunner: ...@@ -357,15 +265,22 @@ class CPUModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens, (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list) ) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, (input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list) attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] seq_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens) seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
# Broadcast the metadata. # Broadcast the metadata.
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
...@@ -385,11 +300,10 @@ class CPUModelRunner: ...@@ -385,11 +300,10 @@ class CPUModelRunner:
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
prompt_lens=None, seq_lens=None,
selected_token_indices=selected_token_indices, selected_token_indices=selected_token_indices,
categorized_sample_indices=None, categorized_sample_indices=None,
generators=None, generators=None,
perform_sampling=False,
) )
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
...@@ -421,7 +335,7 @@ class CPUModelRunner: ...@@ -421,7 +335,7 @@ class CPUModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling: if not self.is_driver_worker:
return None return None
# Sample the next token. # Sample the next token.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment