Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1295 additions
and
682 deletions
+1295
-682
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+15
-14
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+15
-16
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+17
-16
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+293
-65
vllm/sampling_params.py
vllm/sampling_params.py
+8
-4
vllm/sequence.py
vllm/sequence.py
+62
-6
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+51
-38
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+8
-12
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+43
-216
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+176
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+124
-100
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+200
-0
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+108
-7
vllm/transformers_utils/configs/dbrx.py
vllm/transformers_utils/configs/dbrx.py
+7
-6
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+12
-10
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+1
-1
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+1
-1
vllm/usage/usage_lib.py
vllm/usage/usage_lib.py
+9
-7
vllm/utils.py
vllm/utils.py
+105
-37
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+40
-126
No files found.
vllm/model_executor/models/stablelm.py
View file @
1591c68f
...
...
@@ -28,11 +28,12 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -54,7 +55,7 @@ class StablelmMLP(nn.Module):
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
config
.
hidden_size
,
[
config
.
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
False
)
...
...
@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -109,11 +110,11 @@ class StablelmAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_key_value_heads
,
self
.
qkv_bias
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_ndims
,
...
...
@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
self_attn
=
StablelmAttention
(
config
)
self
.
mlp
=
StablelmMLP
(
config
,
linear_method
)
self
.
mlp
=
StablelmMLP
(
config
,
quant_config
)
norm_eps
=
getattr
(
config
,
"norm_eps"
,
getattr
(
config
,
"layer_norm_eps"
,
1e-05
))
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
norm_eps
)
...
...
@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
StablelmDecoderLayer
(
config
,
linear_method
)
StablelmDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
norm_eps
=
getattr
(
config
,
"norm_eps"
,
...
...
@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
StableLMEpochModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
StableLMEpochModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/starcoder2.py
View file @
1591c68f
...
...
@@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
self
.
use_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
self
.
use_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -121,21 +122,20 @@ class Starcoder2MLP(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
c_fc
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
config
.
use_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
c_proj
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
config
.
use_bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
config
.
intermediate_size
)
...
...
@@ -150,12 +150,11 @@ class Starcoder2DecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
Starcoder2Attention
(
config
,
linear_method
=
linear_method
)
self
.
mlp
=
Starcoder2MLP
(
config
,
linear_method
=
linear_method
)
self
.
self_attn
=
Starcoder2Attention
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
Starcoder2MLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -192,7 +191,7 @@ class Starcoder2Model(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -202,7 +201,7 @@ class Starcoder2Model(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
Starcoder2DecoderLayer
(
config
,
linear_method
=
linear_method
)
Starcoder2DecoderLayer
(
config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
...
...
@@ -227,10 +226,10 @@ class Starcoder2ForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
Starcoder2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
model
=
Starcoder2Model
(
config
,
linear_method
=
linear_method
)
self
.
model
=
Starcoder2Model
(
config
,
quant_config
=
quant_config
)
self
.
vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
config
.
tie_word_embeddings
:
...
...
vllm/model_executor/models/xverse.py
View file @
1591c68f
...
...
@@ -31,11 +31,12 @@ from vllm.config import LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -52,17 +53,17 @@ class XverseMLP(nn.Module):
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
...
...
@@ -85,7 +86,7 @@ class XverseAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
...
...
@@ -112,13 +113,13 @@ class XverseAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
bias
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
...
...
@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module):
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"bias"
,
False
),
sliding_window
=
sliding_window
,
)
...
...
@@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module):
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -236,7 +237,7 @@ class XverseModel(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
XverseDecoderLayer
(
config
,
linear_method
)
XverseDecoderLayer
(
config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
XverseModel
(
config
,
linear_method
)
self
.
quant_config
=
quant_config
self
.
model
=
XverseModel
(
config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/sampling_metadata.py
View file @
1591c68f
...
...
@@ -6,57 +6,284 @@ import torch
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
from
vllm.utils
import
is_pin_memory_available
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
@
dataclass
class
SequenceGroupToSample
:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
seq_ids
:
List
[
int
]
sampling_params
:
SamplingParams
# seq_id -> sequence data.
seq_data
:
Dict
[
int
,
SequenceData
]
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
seq_len
:
Optional
[
int
]
# The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of query_len <= seq_len if chunked
# prefill is enabled.
query_len
:
Optional
[
int
]
# A random number generator for sampling.
generator
:
Optional
[
torch
.
Generator
]
# True if the sequence group is in prefill stage. False if it is in a
# decode stage.
is_prompt
:
bool
# Query token indices from logits, used to compute prompt logprob. Empty if
# prompt logprob is not required.
prompt_logprob_indices
:
List
[
int
]
# Sample token indices from logits. Empty if sampling is not required.
sample_indices
:
List
[
int
]
@
property
def
do_sample
(
self
):
return
len
(
self
.
sample_indices
)
>
0
def
__post_init__
(
self
):
if
len
(
self
.
prompt_logprob_indices
)
>
0
:
assert
self
.
sampling_params
.
prompt_logprobs
is
not
None
if
self
.
is_prompt
:
assert
self
.
seq_len
is
not
None
assert
self
.
query_len
is
not
None
class
SamplingMetadata
:
"""Metadata for input sequences. Used in sampler.
The usage is as follows:
```
hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)
def sample(logits):
# Use categorized_sample_indices for sampling....
```
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
seq_groups: List of batched sequence groups.
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
logits from the initial model output hidden states.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
Each set of token indices is a 2D tensor of (num_indices, num_indices) where
the first item means the sample index within the returned logit
(before pruning padding), and the second item means the sample
index after pruning using selected_token_indices.
For example, if the returned logit is [1, 2, 3], and we select
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
"""
def
__init__
(
self
,
seq_groups
:
Optional
[
List
[
Tuple
[
List
[
int
],
SamplingParams
]]],
seq_data
:
Optional
[
Dict
[
int
,
SequenceData
]],
prompt_lens
:
Optional
[
List
[
int
]],
seq_groups
:
List
[
SequenceGroupToSample
],
selected_token_indices
:
torch
.
Tensor
,
categorized_sample_indices
:
Optional
[
Dict
[
SamplingType
,
torch
.
Tensor
]],
generators
:
Optional
[
List
[
torch
.
Generator
]]
=
None
,
perform_sampling
:
bool
=
True
,
categorized_sample_indices
:
Dict
[
SamplingType
,
torch
.
Tensor
],
num_prompts
:
int
,
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
seq_data
=
seq_data
self
.
prompt_lens
=
prompt_lens
self
.
selected_token_indices
=
selected_token_indices
self
.
categorized_sample_indices
=
categorized_sample_indices
self
.
generators
=
generators
self
.
perform_sampling
=
perform_sampling
self
.
num_prompts
=
num_prompts
self
.
num_prompts
=
len
(
prompt_lens
)
if
prompt_lens
is
not
None
else
0
@
staticmethod
def
prepare
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
pin_memory
:
bool
,
)
->
"SamplingMetadata"
:
(
seq_groups
,
selected_token_indices
,
categorized_sample_indices
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
pin_memory
=
pin_memory
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
device
,
pin_memory
=
pin_memory
),
2
,
2
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
num_prompts
=
num_prompts
,
)
return
sampling_metadata
def
__repr__
(
self
)
->
str
:
return
(
"SamplingMetadata("
f
"seq_groups=
{
self
.
seq_groups
}
, "
f
"seq_data=
{
self
.
seq_data
}
, "
f
"prompt_lens=
{
self
.
prompt_lens
}
, "
f
"selected_token_indices=
{
self
.
selected_token_indices
}
, "
f
"categorized_sample_indices=
{
self
.
categorized_sample_indices
}
), "
f
"perform_sampling=
{
self
.
perform_sampling
}
)"
)
f
"categorized_sample_indices=
{
self
.
categorized_sample_indices
}
), "
)
def
_prepare_seq_groups
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
"""Prepare sequence groups and indices for sampling.
Args:
seq_group_metadata_list: A list of sequence groups to batch.
seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
Returns:
seq_groups: A list of sequence groups to sample.
selected_token_indices: See the definition from `SamplingMetadata`.
categorized_sample_indices: See the definition from `SamplingMetadata`.
num_prompts: Total number of prompts from `seq_group_metadata_list`.
"""
# Batched sequence groups for the current model forward step.
seq_groups
:
List
[
SequenceGroupToSample
]
=
[]
# A list of token indices to sample/compute logprob. It is used to
# prune the outcome logits from the model for the performance.
selected_token_indices
:
List
[
int
]
=
[]
# Used for selected_token_indices.
model_output_idx
=
0
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx
=
0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx
=
0
# Total number of prompts from given sequence groups.
num_prompts
=
0
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
is_prompt
=
seq_group_metadata
.
is_prompt
generator
:
Optional
[
torch
.
Generator
]
=
None
# If the current seq group is in decode stage, it is None.
seq_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
do_sample
=
seq_group_metadata
.
do_sample
if
seq_group_metadata
.
is_prompt
:
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
assert
num_prefill_sample
==
1
assert
query_lens
is
not
None
and
seq_lens
is
not
None
query_len
,
seq_len
=
query_lens
[
i
],
seq_lens
[
i
]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len
=
(
query_len
-
num_prefill_sample
if
do_sample
else
query_len
)
sample_len
=
num_prefill_sample
if
do_sample
else
0
else
:
# Decode
prompt_logprob_len
=
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
# Update indices to select from the model output.
"""
This block computes selected_token_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
"""
if
sampling_params
.
prompt_logprobs
:
selected_token_indices
.
extend
(
range
(
model_output_idx
,
model_output_idx
+
prompt_logprob_len
))
model_output_idx
+=
prompt_logprob_len
if
do_sample
:
selected_token_indices
.
extend
(
range
(
model_output_idx
,
model_output_idx
+
sample_len
))
model_output_idx
+=
sample_len
# We now find indices for logprob computation and sampling.
"""
This block computes categorized_sample_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
def sample(logits):
# Use categorized_sample_indices for sampling.
# prompt_logprob_indices to find prompt logprob indices.
# sample_indices to find sample indices.
"""
if
sampling_params
.
prompt_logprobs
is
not
None
:
prompt_logprob_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
prompt_logprob_len
))
logit_idx
+=
prompt_logprob_len
if
do_sample
:
sample_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
sample_len
))
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
list
(
zip
(
range
(
logit_idx
,
logit_idx
+
sample_len
),
range
(
sample_idx
,
sample_idx
+
sample_len
))))
logit_idx
+=
sample_len
sample_idx
+=
sample_len
if
sampling_params
.
seed
is
not
None
:
generator
=
seq_group_metadata
.
state
.
generator
seq_groups
.
append
(
SequenceGroupToSample
(
seq_ids
=
seq_ids
,
sampling_params
=
sampling_params
,
seq_data
=
seq_group_metadata
.
seq_data
,
seq_len
=
seq_len
,
query_len
=
query_len
,
generator
=
generator
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
sample_indices
=
list
(
sample_indices
)))
return
(
seq_groups
,
selected_token_indices
,
categorized_sample_indices
,
num_prompts
)
@
dataclass
...
...
@@ -112,11 +339,10 @@ class SamplingTensors:
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
sample_indices_start_idx
=
0
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_
data
is
not
None
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
for
seq_group
in
sampling_metadata
.
seq_
groups
:
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
temperature
=
sampling_params
.
temperature
p
=
sampling_params
.
presence_penalty
f
=
sampling_params
.
frequency_penalty
...
...
@@ -145,45 +371,46 @@ class SamplingTensors:
or
abs
(
r
-
1.0
)
>=
_SAMPLING_EPS
):
do_penalties
=
True
if
(
i
<
sampling_metadata
.
num_prompts
is_prompt
=
seq_group
.
is_prompt
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# their logprobs
assert
sampling_metadata
.
prompt_lens
is
not
None
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
top_ks
+=
[
top_k
]
*
(
prompt_len
-
1
)
min_ps
+=
[
min_p
]
*
(
prompt_len
-
1
)
presence_penalties
+=
[
0
]
*
(
prompt_len
-
1
)
frequency_penalties
+=
[
0
]
*
(
prompt_len
-
1
)
repetition_penalties
+=
[
1
]
*
(
prompt_len
-
1
)
prompt_tokens
.
extend
([]
for
_
in
range
(
prompt_len
-
1
))
output_tokens
.
extend
([]
for
_
in
range
(
prompt_len
-
1
))
for
seq_id
in
seq_ids
:
seq_data
=
sampling_metadata
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
output_tokens
.
append
(
seq_data
.
output_token_ids
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
min_ps
+=
[
min_p
]
*
len
(
seq_ids
)
presence_penalties
+=
[
p
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
is_prompt
=
i
<
sampling_metadata
.
num_prompts
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
temperatures
+=
[
temperature
]
*
prefill_len
top_ps
+=
[
top_p
]
*
prefill_len
top_ks
+=
[
top_k
]
*
prefill_len
min_ps
+=
[
min_p
]
*
prefill_len
presence_penalties
+=
[
0
]
*
prefill_len
frequency_penalties
+=
[
0
]
*
prefill_len
repetition_penalties
+=
[
1
]
*
prefill_len
prompt_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
if
seq_group
.
do_sample
:
sample_lens
=
len
(
seq_group
.
sample_indices
)
assert
sample_lens
==
len
(
seq_ids
)
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
output_tokens
.
append
(
seq_data
.
output_token_ids
)
temperatures
+=
[
temperature
]
*
len
(
seq_ids
)
top_ps
+=
[
top_p
]
*
len
(
seq_ids
)
top_ks
+=
[
top_k
]
*
len
(
seq_ids
)
min_ps
+=
[
min_p
]
*
len
(
seq_ids
)
presence_penalties
+=
[
p
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
assert
sampling_metadata
.
prompt_lens
is
not
None
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx
+=
prompt_len
-
1
for
seq_id
in
seq_ids
:
seq_data
=
s
ampling_metadata
.
seq_data
[
seq_id
]
seq_data
=
s
eq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
...
...
@@ -193,8 +420,7 @@ class SamplingTensors:
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
append
(
sample_indices_start_idx
)
sample_indices_start_idx
+=
1
sample_indices
.
extend
(
seq_group
.
sample_indices
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
...
...
@@ -217,12 +443,14 @@ class SamplingTensors:
# Note that the performance will be very bad without
# pinned memory.
pin_memory
=
is_pin_memory_available
()
prompt_max_len
=
max
(
len
(
tokens
)
for
tokens
in
prompt_tokens
)
prompt_max_len
=
max
([
len
(
tokens
)
for
tokens
in
prompt_tokens
],
default
=
0
)
prompt_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
for
tokens
in
prompt_tokens
]
output_max_len
=
max
(
len
(
tokens
)
for
tokens
in
output_tokens
)
output_max_len
=
max
([
len
(
tokens
)
for
tokens
in
output_tokens
],
default
=
0
)
output_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
for
tokens
in
output_tokens
...
...
vllm/sampling_params.py
View file @
1591c68f
...
...
@@ -139,7 +139,10 @@ class SamplingParams:
self
.
top_p
=
top_p
self
.
top_k
=
top_k
self
.
min_p
=
min_p
self
.
seed
=
seed
if
seed
==
-
1
:
self
.
seed
=
None
else
:
self
.
seed
=
seed
self
.
use_beam_search
=
use_beam_search
self
.
length_penalty
=
length_penalty
self
.
early_stopping
=
early_stopping
...
...
@@ -185,8 +188,8 @@ class SamplingParams:
self
.
top_k
=
-
1
self
.
min_p
=
0.0
self
.
_verify_greedy_sampling
()
#
injected
by the engine
self
.
eos
_token_id
=
None
#
eos_token_id is added to this
by the engine
self
.
all_stop
_token_id
s
=
set
(
self
.
stop_token_ids
)
def
_verify_args
(
self
)
->
None
:
if
self
.
n
<
1
:
...
...
@@ -275,7 +278,8 @@ class SamplingParams:
self
,
generation_config
:
Dict
[
str
,
Any
])
->
None
:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
):
if
(
not
self
.
ignore_eos
)
and
(
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
)):
# it can be either int or list of int
if
isinstance
(
eos_ids
,
int
):
eos_ids
=
[
eos_ids
]
...
...
vllm/sequence.py
View file @
1591c68f
"""Sequence and its related classes."""
import
copy
import
enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Union
from
vllm.block
import
LogicalTokenBlock
...
...
@@ -28,7 +28,10 @@ class Logprob:
decoded_token
:
Optional
[
str
]
=
None
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs
=
List
[
Optional
[
Dict
[
int
,
Logprob
]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs
=
List
[
Dict
[
int
,
Logprob
]]
...
...
@@ -215,7 +218,7 @@ class Sequence:
self
.
eos_token_id
=
eos_token_id
self
.
lora_request
=
lora_request
self
.
data
=
SequenceData
(
prompt_token_ids
)
self
.
data
:
SequenceData
=
SequenceData
(
prompt_token_ids
)
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_text
=
""
...
...
@@ -439,15 +442,27 @@ class SequenceGroup:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
def
get_last_latency
(
self
,
now
:
float
)
->
float
:
"""Gets last token latency for Request level timings."""
def
get_last_latency
(
self
,
now
:
float
)
->
Optional
[
float
]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if
self
.
is_prefill
():
raise
ValueError
(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase."
)
# Otherwise return token latency.
latency
=
now
-
self
.
metrics
.
last_token_time
self
.
metrics
.
last_token_time
=
now
return
latency
def
maybe_set_first_token_time
(
self
,
time
:
float
)
->
None
:
"""Sets the first token time for Request level timings."""
if
self
.
metrics
.
first_token_time
is
None
:
# Note: in a case where a sequence_group is swapped and
# recomputed, the time between iterations is counted
# in TPOT, rather than recalculating TTFT (since from the )
# POV of the user, there is simply a long generation delay.
if
(
self
.
metrics
.
first_token_time
is
None
and
self
.
get_seqs
()[
0
].
get_output_len
()
==
1
):
self
.
metrics
.
first_token_time
=
time
def
maybe_set_first_scheduled_time
(
self
,
time
:
float
)
->
None
:
...
...
@@ -559,10 +574,15 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
do_sample: True if sampling is required. Sampling is not required when
e.g., prefill is chunked, and the current iteration only computes
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
"""
...
...
@@ -573,6 +593,7 @@ class SequenceGroupMetadata:
seq_data
:
Dict
[
int
,
SequenceData
],
sampling_params
:
SamplingParams
,
block_tables
:
Dict
[
int
,
List
[
int
]],
do_sample
:
bool
=
True
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -589,6 +610,7 @@ class SequenceGroupMetadata:
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
_token_chunk_size
=
token_chunk_size
self
.
do_sample
=
do_sample
if
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
...
...
@@ -650,6 +672,7 @@ class SequenceGroupOutput:
prompt_logprobs
:
Optional
[
PromptLogprobs
],
)
->
None
:
self
.
samples
=
samples
# Prompt logprob for each prompt query token.
self
.
prompt_logprobs
=
prompt_logprobs
def
__repr__
(
self
)
->
str
:
...
...
@@ -677,6 +700,9 @@ class SamplerOutput:
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the sampled token ids.
sampled_token_ids
:
Optional
[
"torch.Tensor"
]
=
None
...
...
@@ -708,3 +734,33 @@ class SamplerOutput:
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
@dataclass
class ExecuteModelRequest:
    """Bundle of inputs describing one model execution step.

    Only ``seq_group_metadata_list`` is required; every other field has an
    empty / zero default.
    """
    # Metadata of the sequence groups to run in this step.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Swap-in mapping: CPU block number -> GPU block number.
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    # Swap-out mapping: GPU block number -> CPU block number.
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    # Copy mapping: source block number -> list of destination blocks.
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    # Number of slots reserved for lookahead decoding.
    num_lookahead_slots: int = 0
    # Number of requests currently in the running queue.
    running_queue_size: int = 0

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Return a new request that uses ``seq_group_metadata_list``.

        The three block mappings are copied with ``dict.copy()`` (a shallow
        copy: the destination lists inside ``blocks_to_copy`` are shared with
        this request); the two integer fields are carried over as-is.
        """
        swap_in = self.blocks_to_swap_in.copy()
        swap_out = self.blocks_to_swap_out.copy()
        to_copy = self.blocks_to_copy.copy()
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=swap_in,
            blocks_to_swap_out=swap_out,
            blocks_to_copy=to_copy,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
        )
vllm/spec_decode/batch_expansion.py
View file @
1591c68f
from
itertools
import
chain
,
count
from
typing
import
Dict
,
Iterator
,
List
,
Optional
,
Tuple
from
typing
import
Iterator
,
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
...
...
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@
nvtx_range
(
"BatchExpansionTop1Scorer.score_proposals"
)
def
score_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
"""Score the proposed tokens via the scorer model.
...
...
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence.
Args:
seq_group_metadata_list: The input sequence group metadata.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
...
...
@@ -80,33 +73,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
)
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
seq_group_metadata_list
=
target_seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
,
))
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
contracted_bs
=
len
(
seq_group_metadata_list
),
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
k
=
k
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
return
SpeculativeScores
(
probs
=
all_probs
,
token_ids
=
all_tokens
,
logprobs
=
spec_logprobs
,
)
def
_expand_batch
(
...
...
@@ -148,12 +139,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
]
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
]
,
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
]
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
...
...
@@ -161,8 +152,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
# Map distinct sequences used to score each token
...
...
@@ -179,6 +171,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
target_logprobs
=
target_logprobs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
all_tokens
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
,
...
...
@@ -189,16 +183,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self
.
_vocab_size
,
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
all_logprobs
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
,
self
.
_vocab_size
,
),
fill_value
=-
float
(
"inf"
),
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
return
all_tokens
,
all_probs
return
all_tokens
,
all_probs
,
all_logprobs
def
_create_scoring_model_input
(
self
,
...
...
@@ -308,7 +312,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Split the target model output into speculative and non-speculative
output.
"""
...
...
@@ -328,21 +333,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
)
=
sampler_output
.
sampled_token_ids
.
flatten
().
split
(
split_sizes
)
(
spec_logprobs
,
non_spec_logprobs
,
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
# Convert scores to tensors.
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
target_token_ids
,
target_probs
=
sampler_output_to_torch
(
[
sampler_output
])
sampler_output
.
logprobs
=
spec_logprobs
(
target_token_ids
,
target_probs
,
target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
non_spec_target_token_ids
,
non_spec_target_probs
=
(
sampler_output_to_torch
([
sampler_output
]))
return
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
sampler_output
.
logprobs
=
non_spec_logprobs
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
...
...
vllm/spec_decode/interfaces.py
View file @
1591c68f
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
import
torch
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
@
dataclass
...
...
@@ -38,6 +37,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model.
probs
:
torch
.
Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs
:
torch
.
Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids
:
torch
.
Tensor
...
...
@@ -53,11 +57,7 @@ class SpeculativeProposer(ABC):
@
abstractmethod
def
get_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
raise
NotImplementedError
...
...
@@ -67,11 +67,7 @@ class SpeculativeScorer(ABC):
@
abstractmethod
def
score_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
execute_model_req
:
ExecuteModelRequest
,
proposals
:
SpeculativeProposals
,
)
->
SpeculativeScores
:
raise
NotImplementedError
vllm/spec_decode/multi_step_worker.py
View file @
1591c68f
import
copy
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativePropos
er
)
from
vllm.spec_decode.
util
import
sampler_output_to_torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativePropos
als
from
vllm.spec_decode.
top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
...
...
@@ -26,50 +26,53 @@ class MultiStepWorker(Worker):
super
().
__init__
(
*
args
,
**
kwargs
)
# Lazy initialization list.
self
.
_proposer
:
DraftModel
Top1Proposer
self
.
_proposer
:
Top1Proposer
def
init_device
(
self
):
super
().
init_device
()
self
.
_proposer
=
DraftModel
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
self
,
self
.
device
,
self
.
max_model_len
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
):
# Need include_gpu_probs_tensor for multi_step_worker
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
@
torch
.
inference_mode
()
def
execute_model_multi_step
(
def
sampler_output
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_steps
:
int
,
)
->
List
[
SamplerOutput
]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
execute_model_req
:
ExecuteModelRequest
,
sample_len
:
int
,
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
self
.
_raise_if_unsupported
(
execute_model_req
)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list
=
self
.
_shallow_copy_inputs
(
seq_group_metadata_list
)
execute_model_req
.
seq_group_metadata_list
)
copied_execute_model_req
=
execute_model_req
.
clone
(
copied_seq_group_metadata_list
)
# Assert enough KV space for num_steps tokens per sequence.
self
.
_assert_enough_kv_space
(
seq_group_metadata_list
,
num_steps
)
# Assert enough KV space for sample_len tokens per sequence.
self
.
_assert_enough_kv_space
(
execute_model_req
.
seq_group_metadata_list
,
sample_len
)
# Run model
num_steps
times.
# Run model
sample_len
times.
model_outputs
=
[]
for
_
in
range
(
num_steps
):
for
_
in
range
(
sample_len
):
model_output
=
super
().
execute_model
(
seq_group_metadata_list
=
copied_seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
execute_model_req
=
copied_execute_model_req
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
...
...
@@ -78,27 +81,17 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list
)
model_outputs
.
append
(
model_output
)
return
model_outputs
return
model_outputs
,
True
def
get_spec_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
max_proposal_len
,
)
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
def
_append_new_tokens
(
self
,
model_output
:
SamplerOutput
,
...
...
@@ -189,188 +182,22 @@ class MultiStepWorker(Worker):
def
_raise_if_unsupported
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
execute_model_req
:
ExecuteModelRequest
,
)
->
None
:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
]):
if
any
([
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
]):
raise
NotImplementedError
(
"MultiStepWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
seq_group_metadata_list
):
for
seq_group_metadata
in
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
"MultiStepWorker does not support beam search."
)
class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        draft_worker: MultiStepWorker,
        device: str,
        max_model_len: int,
        vocab_size: int,
    ):
        """Store the draft worker and the limits used when proposing.

        Args:
            draft_worker: Worker whose ``execute_model_multi_step`` produces
                the draft tokens.
            device: Device on which the proposal tensors are allocated.
            max_model_len: Sequences whose length plus the proposal length
                reaches this value are not speculated on.
            vocab_size: Size of the last dimension of the probability tensor.
        """
        self._draft_worker = draft_worker
        self._device = device
        self._max_model_len = max_model_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.

        Args:
            seq_group_metadata_list: The input batch, one entry per sequence
                group.
            blocks_to_swap_in: Forwarded unchanged to the draft worker.
            blocks_to_swap_out: Forwarded unchanged to the draft worker.
            blocks_to_copy: Forwarded unchanged to the draft worker.
            max_proposal_len: Number of tokens proposed for every sequence
                that is speculated on.

        Returns:
            Proposals covering the whole input batch; skipped sequences have a
            proposal length of 0.
        """

        # Split speculative- and non-speculative- sequences.
        (proposal_lens, nonzero_proposal_len_seqs,
         nonzero_proposal_len_indices) = self._split_by_max_model_len(
             seq_group_metadata_list, max_proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_steps=max_proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            max_proposal_len=max_proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.

        Returns:
            A tuple of (per-sequence proposal length, the sequence groups that
            will be speculated on, their indices in the input batch).
        """

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            # Only the first sequence of each group is inspected.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (proposal_lens, nonzero_proposal_len_seqs,
                nonzero_proposal_len_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        max_proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Args:
            batch_size: Number of sequence groups in the original batch.
            max_proposal_len: Proposal length assigned to speculated sequences.
            maybe_sampler_output: Draft worker output (one SamplerOutput per
                step), or None when no sequence was speculated on.
            proposal_lens: Per-sequence proposal lengths from the split step.
            nonzero_proposal_len_indices: Batch indices of the sequences that
                were speculated on.

        Returns:
            A tuple of (proposal token ids, proposal probabilities, proposal
            lengths), each with ``batch_size`` as its first dimension.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.full(size=(
                batch_size,
                max_proposal_len,
            ),
                                         fill_value=-1,
                                         dtype=torch.long,
                                         device=self._device)
            proposal_probs = torch.zeros(batch_size,
                                         max_proposal_len,
                                         self._vocab_size,
                                         dtype=torch.float32,
                                         device=self._device)
            proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                               dtype=torch.long,
                                               device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs = sampler_output_to_torch(
            sampler_output)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = torch.full(size=(batch_size,
                                                  *proposal_tokens.shape[1:]),
                                            fill_value=-1,
                                            dtype=torch.long,
                                            device=self._device)
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(batch_size,
                                            *proposal_probs.shape[1:],
                                            dtype=torch.float32,
                                            device=self._device)
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (entire_proposal_tokens,
                                           entire_proposal_probs)

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
vllm/spec_decode/ngram_worker.py
0 → 100644
View file @
1591c68f
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class NGramWorker(LoraNotSupportedWorkerBase):
    """NGramWorker provides a light drafter without need for model.

    Current NGramWorker only implements prompt lookup decoding,
    and in future we may also do RAG type drafter and other scenarios
    which don't rely on LLM model to give proposals.
    """

    def __init__(self, *args, **kwargs):
        """Read ``local_rank`` and the vocabulary size from ``kwargs``.

        Raises:
            KeyError: If ``local_rank`` or ``model_config`` is not passed as a
                keyword argument.
        """
        # Get local_rank/vocab_size from kwargs attribute
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Lazy initialization list.
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        """Set the n-gram sizes (min..max, both inclusive) to search for."""
        # Search valid candidate window between
        # ngram_prompt_lookup_min/ngram_prompt_lookup_max
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        """Select the CUDA device for ``local_rank`` and build the proposer."""
        self.device = torch.device(f"cuda:{self.local_rank}")
        # There is no model to load; expose a no-op ``load_model``.
        self.load_model = lambda *args, **kwargs: None

        # Current only support Top1Proposer
        self._proposer = Top1Proposer(
            self,
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def set_include_gpu_probs_tensor(self):
        """No-op: NGram doesn't need the GPU sampler."""
        pass

    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
        """NGram doesn't depend on model execution, just pass this function"""
        pass

    def determine_num_available_blocks(self) -> None:
        """NGram doesn't depend on model execution, no need to check blocks"""
        pass

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """As there is no cache need to handle, just pass this function"""
        pass

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes."""
        return 0

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """NGram match algo to pick proposal candidate. Returns the list of
        sampler output, one per SequenceGroupMetadata.

        For ngram worker, we already done needed transposed internal, so the
        indicator pass to sampler_output_to_torch shall be False.

        Args:
            execute_model_req: The request whose sequences are searched.
            sample_len: Number of tokens to propose per sequence.

        Returns:
            ``(None, False)`` when no sequence in the batch has an n-gram
            match; otherwise ``(outputs, False)`` with one SamplerOutput per
            sequence group. Sequences without a match (and the tail of short
            matches) are filled with token id 0.

        Raises:
            NotImplementedError: If the request contains cache operations or
                a sequence group with more than one sequence.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_groups = execute_model_req.seq_group_metadata_list
        batch_size = len(seq_groups)

        # Never search with an n-gram of size < 1: a size of 0 would make the
        # suffix slice below select the whole sequence.
        smallest_ngram = max(self.ngram_prompt_lookup_min, 1)

        arr = []
        has_spec_out = False
        for seq_group_metadata in seq_groups:
            # Only the first sequence of the group is used;
            # _raise_if_unsupported guarantees there is exactly one.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            input_length = seq_data.get_len()

            # Try the longest n-gram first, down to and including the
            # configured minimum size (hence the ``- 1`` on the stop value).
            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, input_length - 1),
                    smallest_ngram - 1,
                    -1,
            ):
                # The n-gram searched for is the suffix of the sequence.
                ngram_tensor = input_ids[-1 * ngram_size:]
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                matches = (windows == ngram_tensor).all(dim=1)
                match_indices = matches.nonzero(as_tuple=True)[0]
                # The suffix always matches itself, so a real hit needs more
                # than one matching window.
                if match_indices.size()[0] > 1:
                    has_spec_out = True
                    start = int(match_indices[0]) + ngram_size
                    res = list(
                        seq_data.get_token_ids()[start:start + sample_len])
                    # pad 0 towards output as sample_len tokens required
                    res += [0] * (sample_len - len(res))

                    break
            else:
                # if no candidate found, fill with 0
                res = [0] * sample_len

            arr.append(res)

        if not has_spec_out:
            return None, False

        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
        indices = token_ids.unsqueeze(2)

        # One-hot probabilities: probability 1 on each proposed token.
        token_probs = torch.zeros(
            (batch_size, sample_len, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        token_probs.scatter_(2, indices, 1)
        token_logprobs = torch.zeros(
            (batch_size, sample_len, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        # Each output gets its own row of every tensor (including logprobs)
        # so that all three fields have consistent per-sequence shapes.
        outputs = [
            SamplerOutput(
                outputs=None,
                sampled_token_probs=token_probs[i],
                logprobs=token_logprobs[i],
                sampled_token_ids=token_ids[i],
            ) for i in range(batch_size)
        ]
        return outputs, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_proposals(execute_model_req)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: If any block swap/copy mapping is non-empty,
                or a sequence group holds a number of sequences other than 1.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")
vllm/spec_decode/spec_decode_worker.py
View file @
1591c68f
from
functools
import
cached_property
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroup
Output
,
SequenceOutput
)
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroup
Metadata
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
...
...
@@ -48,8 +51,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
@
classmethod
def
from_workers
(
cls
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
WorkerBase
)
->
"SpecDecodeWorker"
:
def
create_worker
(
cls
,
scorer_worker
:
WorkerBase
,
draft_worker_kwargs
,
)
->
"SpecDecodeWorker"
:
if
"ngram_prompt_lookup_max"
in
draft_worker_kwargs
:
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
else
:
ngram_prompt_lookup_max
=
0
if
ngram_prompt_lookup_max
>
0
:
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
else
:
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
...
...
@@ -59,7 +81,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
__init__
(
self
,
proposer_worker
:
MultiStep
Worker
,
proposer_worker
:
Worker
Base
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
...
...
@@ -134,8 +156,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
(
self
.
proposer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
...
...
@@ -169,68 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""
assert
seq_group_metadata_list
is
not
None
,
(
assert
execute_model_req
.
seq_group_metadata_list
is
not
None
,
(
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
logger
.
info
(
f
"spec_decode_worker.execute_model
{
num_lookahead_slots
=
}
"
)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if
num_lookahead_slots
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
return
self
.
_run_speculative_decoding_step
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
k
=
num_lookahead_slots
,
)
if
execute_model_req
.
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
execute_model_req
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
)
->
List
[
SamplerOutput
]:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
"""
logger
.
info
(
"run proposer worker no spec"
)
#
logger.info("run proposer worker no spec")
self
.
proposer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
self
.
proposer_worker
.
execute_model
(
execute_model_req
)
logger
.
info
(
"run target worker no spec"
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
#logger.info("run target worker no spec")
sampler_output
=
self
.
scorer_worker
.
execute_model
(
execute_model_req
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
...
...
@@ -238,17 +228,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers.
sampler_output
.
probs
=
None
sampler_output
.
sampled_tokens
=
None
sampler_output
.
logprobs
=
None
return
[
sampler_output
]
@
nvtx_range
(
"spec_decode_worker._run_speculative_decoding_step"
)
def
_run_speculative_decoding_step
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
)
->
List
[
SamplerOutput
]:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
...
...
@@ -258,32 +244,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
logger
.
info
(
"get spec proposals"
)
#
logger.info("get spec proposals")
# Generate proposals using draft worker.
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
logger
.
info
(
"score proposals"
)
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
)
#logger.info("score proposals")
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
,
execute_model_req
,
proposals
,
)
logger
.
info
(
"verify proposals"
)
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
#logger.info("verify proposals")
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
logger
.
info
(
"create output list"
)
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
k
)
#logger.info("create output list")
return
self
.
_create_output_sampler_list
(
execute_model_req
.
seq_group_metadata_list
,
accepted_token_ids
,
target_logprobs
=
target_logprobs
,
k
=
execute_model_req
.
num_lookahead_slots
)
@
nvtx_range
(
"spec_decode_worker._verify_tokens"
)
def
_verify_tokens
(
...
...
@@ -292,9 +273,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
)
->
torch
.
Tensor
:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
...
...
@@ -341,17 +325,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
logprobs
=
proposal_scores
.
logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
return
accepted_token_ids
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
accepted_token_ids
:
torch
.
Tensor
,
# shape: [batch_size, k+1]
target_logprobs
:
torch
.
Tensor
,
# shape: [batch_size, k+1, vocab_size]
k
:
int
,
)
->
List
[
SamplerOutput
]:
"""Given the accepted token ids, create a list of SamplerOutput.
...
...
@@ -359,30 +345,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
batch_size
,
num_steps
=
accepted_token_ids
.
shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
# Get the logprobs/rank of the accepted tokens.
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
)
=
get_sampled_token_logprobs
(
logprob_tensor
=
target_logprobs_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(
topk_logprobs_by_step
,
topk_indices_by_step
)
=
target_logprobs_by_step
.
topk
(
k
=
self
.
scorer_worker
.
model_config
.
max_logprobs
,
dim
=-
1
,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
# shape: [k+1, batch_size]
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
).
tolist
()
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list
=
[]
for
token_ids_by_step
in
accepted_token_ids_by_step
:
if
all
(
token_id
==
-
1
for
token_id
in
token_ids_by_step
):
for
step_index
in
range
(
num_steps
):
if
all
(
token_id
==
-
1
for
token_id
in
accepted_token_ids_by_step
[
step_index
]):
break
step_output_token_ids
=
[]
for
token_id
,
seq_id
in
zip
(
token_ids_by_step
,
seq_ids
):
for
sequence_index
in
range
(
batch_size
):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs
=
num_logprobs_per_seq
[
sequence_index
]
step_output_token_ids
.
append
(
SequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
# TODO Add verifier logprobs.
logprobs
=
{
token_id
:
Logprob
(
0.0
)},
)
],
prompt_logprobs
=
None
,
create_sequence_group_output
(
token_id
=
accepted_token_ids_by_step
[
step_index
]
[
sequence_index
],
token_id_logprob_rank
=
accepted_token_id_ranks_by_step
[
step_index
][
sequence_index
],
token_id_logprob
=
accepted_token_id_logprobs_by_step
[
step_index
][
sequence_index
],
seq_id
=
seq_ids
[
sequence_index
],
topk_token_ids
=
topk_indices_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
topk_logprobs
=
topk_logprobs_by_step
[
step_index
]
[
sequence_index
][:
num_logprobs
],
))
sampler_output_list
.
append
(
SamplerOutput
(
outputs
=
step_output_token_ids
))
...
...
vllm/spec_decode/top1_proposer.py
0 → 100644
View file @
1591c68f
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.worker.worker_base
import
WorkerBase
class Top1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        worker: WorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        """Store the draft worker and the tensor layout parameters.

        Args:
            worker: Draft worker; its ``sampler_output`` method is called to
                produce the speculative tokens.
            device: Device on which the merged proposal tensors are
                allocated.
            vocab_size: Size of the last dimension of the all-zero
                probability tensor returned when nothing is speculated.
            max_proposal_len: Length bound used to decide which sequences
                are speculated on: a sequence is speculated on only if
                ``seq_len + proposal_len < max_proposal_len``. ``None``
                disables the check.
        """
        self._worker = worker
        self._device = device
        self.max_proposal_len = max_proposal_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.

        Args:
            execute_model_req: The batch to speculate on. Its
                ``num_lookahead_slots`` is used as the proposal length and
                its ``seq_group_metadata_list`` as the batch.

        Returns:
            The proposals for every sequence of the batch, including the
            skipped ones (which get a proposal length of 0).
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_max_model_len(seq_group_metadata_list,
                                         proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences only. Note that the sub-request is built from the
            # sequence list and the lookahead slots alone.
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
            )
            # The worker also reports the layout of its output: when
            # `transposed` is true, the token ids come as a list of
            # proposal_len entries, each in [batch] format; when it is
            # false, they come as a batch-sized list of [proposal_len]
            # entries.
            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        (
            proposal_tokens,
            proposal_probs,
            proposal_lens_tensor,
        ) = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens_tensor,
        )

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.

        Args:
            seq_group_metadata_list: The batch; only the first sequence of
                each group is inspected.
            proposal_len: The global batch proposal length.

        Returns:
            A tuple ``(proposal_lens, nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices)``: the per-sequence proposal length
            (either ``proposal_len`` or 0), the sequence groups that will be
            speculated on, and their indices in the input batch.
        """
        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal
            # len are supported.
            # If max_proposal_len is defined, a sequence is only speculated
            # on when doing so does not exceed this quota.
            if (self.max_proposal_len is None
                    or seq_len + proposal_len < self.max_proposal_len):
                proposal_lens.append(proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Args:
            batch_size: Number of sequences in the full batch.
            proposal_len: The global batch proposal length.
            maybe_sampler_output: Output of the draft worker for the
                speculated sequences, or None if nothing was speculated.
            proposal_lens: Per-sequence proposal lengths of the full batch.
            nonzero_proposal_len_indices: Batch indices of the sequences
                that were speculated on.
            sampler_transposed: Layout flag forwarded to
                ``sampler_output_to_torch``.

        Returns:
            A tuple ``(proposal_tokens, proposal_probs, proposal_lens)`` of
            tensors on ``self._device`` whose first dimension covers the
            full batch. Skipped sequences have token ids of -1, all-zero
            probabilities and a proposal length of 0.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.full(
                size=(
                    batch_size,
                    proposal_len,
                ),
                fill_value=-1,
                dtype=torch.long,
                device=self._device,
            )
            proposal_probs = torch.zeros(
                batch_size,
                proposal_len,
                self._vocab_size,
                dtype=torch.float32,
                device=self._device,
            )
            proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                               dtype=torch.long,
                                               device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = torch.full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
            dtype=torch.long,
            device=self._device,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(
            batch_size,
            *proposal_probs.shape[1:],
            dtype=torch.float32,
            device=self._device,
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
vllm/spec_decode/util.py
View file @
1591c68f
from
contextlib
import
contextmanager
from
itertools
import
chain
from
typing
import
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SeqId
=
int
...
...
@@ -21,6 +22,89 @@ def get_all_seq_ids(
]))
def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Collect the requested number of logprobs of every sequence group.

    The value is read from ``sampling_params.logprobs`` of each entry, in
    input order. An entry whose sampling params do not ask for logprobs
    (``None``) is reported as 0.
    """
    requested = (metadata.sampling_params.logprobs
                 for metadata in seq_group_metadata_list)
    return [0 if count is None else count for count in requested]
def get_sampled_token_logprobs(
        # shape [num_steps, batch_size, vocab_size]
        logprob_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Look up the logprob and the rank of every sampled token.

    Returns:
        A tuple ``(ranks, logprobs)``, both of shape
        [num_steps, batch_size]. The rank of a token is the number of
        vocabulary entries whose logprob is greater than or equal to the
        logprob of that token, so the most likely token has rank 1.
    """
    num_steps, batch_size, _ = logprob_tensor.shape

    # Index every (step, sequence) pair at its sampled token id.
    step_index = torch.arange(num_steps).unsqueeze(1)
    sequence_index = torch.arange(batch_size)
    selected_logprobs = logprob_tensor[step_index, sequence_index,
                                       sampled_token_ids]

    # Broadcasting the trailing singleton dimension compares each selected
    # logprob against the whole vocabulary of its (step, sequence) pair.
    at_least_selected = logprob_tensor >= selected_logprobs.unsqueeze(-1)
    sampled_token_ids_ranks = at_least_selected.sum(-1)

    return sampled_token_ids_ranks, selected_logprobs
def create_sequence_group_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    seq_id: SeqId,
    topk_token_ids: List[int],
    topk_logprobs: List[float],
) -> SequenceGroupOutput:
    """Build the SequenceGroupOutput of a single sampled token.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[int]): The list of top-k token ids.
        topk_logprobs (List[float]): The list of top-k logprobs.

    Returns:
        A SequenceGroupOutput with one sample and no prompt logprobs.
    """
    # The sampled token is always reported. The top-k entries requested by
    # the user are added afterwards, ranked by their position in the list
    # (starting at 1); a top-k entry for the sampled token replaces the
    # entry created here.
    sampled_entry = Logprob(
        logprob=token_id_logprob,
        rank=token_id_logprob_rank,
    )
    logprobs: Dict[int, Logprob] = {token_id: sampled_entry}
    for position, topk_token_id in enumerate(topk_token_ids):
        logprobs[topk_token_id] = Logprob(
            logprob=topk_logprobs[position],
            rank=position + 1,
        )

    sample = SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
                            logprobs=logprobs)
    return SequenceGroupOutput(
        samples=[sample],
        # TODO add prompt logprobs support.
        prompt_logprobs=None,
    )
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
select_proposal_len_zero
:
bool
...
...
@@ -49,10 +133,13 @@ def split_batch_by_proposal_len(
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
...
...
@@ -68,7 +155,19 @@ def sampler_output_to_torch(
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
).
transpose
(
0
,
1
)
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
...
...
@@ -77,9 +176,11 @@ def sampler_output_to_torch(
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
).
transpose
(
0
,
1
)
)
if
sampler_transposed
:
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
return
sampled_token_ids
,
sampled_token_probs
return
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
...
...
vllm/transformers_utils/configs/dbrx.py
View file @
1591c68f
...
...
@@ -72,9 +72,10 @@ class DbrxAttentionConfig(PretrainedConfig):
and
config_dict
[
"model_type"
]
!=
cls
.
model_type
):
logger
.
warning
(
f
"You are using a model of type
{
config_dict
[
'model_type'
]
}
to instantiate a model of type "
+
f
"
{
cls
.
model_type
}
. This is not supported for all configurations of models and can yield errors."
)
"You are using a model of type %s to instantiate a model of "
"type %s. This is not supported for all configurations of "
"models and can yield errors."
,
config_dict
[
"model_type"
],
cls
.
model_type
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
...
...
@@ -151,9 +152,9 @@ class DbrxFFNConfig(PretrainedConfig):
and
config_dict
[
"model_type"
]
!=
cls
.
model_type
):
logger
.
warning
(
f
"You are using a model of type
{
config_dict
[
'model_type'
]
}
to instantiate a model of
type
"
+
f
"
{
cls
.
model_
type
}
. This is not supported for all
configurations of models and can yield errors.
"
)
"You are using a model of type
%s
to instantiate a model of "
"
type
%s
. This is not supported for all "
"configurations of models and can yield errors."
,
config_dict
[
"model_type"
],
cls
.
model_type
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
...
...
vllm/transformers_utils/tokenizer.py
View file @
1591c68f
import
os
from
typing
import
Optional
,
Union
import
huggingface_hub
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
from
vllm.
config
import
VLLM_USE_MODELSCOPE
from
vllm.
envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizers
import
BaichuanTokenizer
...
...
@@ -58,11 +59,12 @@ def get_tokenizer(
*
args
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
tokenizer_
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
download_dir
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
"""Gets a tokenizer for the given model name via Huggingface/modelscope."""
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
...
...
@@ -74,9 +76,10 @@ def get_tokenizer(
tokenizer_path
=
snapshot_download
(
model_id
=
tokenizer_name
,
cache_dir
=
download_dir
,
revision
=
tokenizer_revision
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern
=
[
"*.pt"
,
"*.safetensors"
,
"*.bin"
])
ignore_file_pattern
=
[
"
.
*.pt"
,
"
.
*.safetensors"
,
"
.
*.bin"
])
tokenizer_name
=
tokenizer_path
if
tokenizer_mode
==
"slow"
:
...
...
@@ -90,7 +93,7 @@ def get_tokenizer(
tokenizer_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
tokenizer_revision
=
tokenizer_
revision
,
revision
=
revision
,
**
kwargs
)
except
ValueError
as
e
:
# If the error pertains to the tokenizer class not existing or not
...
...
@@ -114,7 +117,7 @@ def get_tokenizer(
tokenizer_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
tokenizer_revision
=
tokenizer_
revision
,
revision
=
revision
,
**
kwargs
)
else
:
raise
e
...
...
@@ -137,9 +140,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger
.
warning
(
f
"No tokenizer found in
{
lora_request
.
lora_local_path
}
, "
"using base model tokenizer instead. "
f
"(Exception:
{
str
(
e
)
}
)"
)
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)"
,
lora_request
.
lora_local_path
,
e
)
tokenizer
=
None
return
tokenizer
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
1591c68f
from
typing
import
Optional
from
vllm.config
import
TokenizerPoolConfig
from
vllm.e
ngine
.ray_utils
import
ray
from
vllm.e
xecutor
.ray_utils
import
ray
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
1591c68f
...
...
@@ -6,7 +6,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.e
ngine
.ray_utils
import
ray
from
vllm.e
xecutor
.ray_utils
import
ray
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
...
...
vllm/usage/usage_lib.py
View file @
1591c68f
...
...
@@ -15,20 +15,22 @@ import psutil
import
requests
import
torch
_config_home
=
os
.
getenv
(
"XDG_CONFIG_HOME"
,
os
.
path
.
expanduser
(
"~/.config"
))
import
vllm.envs
as
envs
_config_home
=
envs
.
VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH
=
os
.
path
.
join
(
_config_home
,
"vllm/usage_stats.json"
)
_USAGE_STATS_DO_NOT_TRACK_PATH
=
os
.
path
.
join
(
_config_home
,
"vllm/do_not_track"
)
_USAGE_STATS_ENABLED
=
None
_USAGE_STATS_SERVER
=
os
.
environ
.
get
(
"VLLM_USAGE_STATS_SERVER"
,
"https://stats.vllm.ai"
)
_USAGE_STATS_SERVER
=
envs
.
VLLM_USAGE_STATS_SERVER
def
is_usage_stats_enabled
():
"""Determine whether or not we can send usage stats to the server.
The logic is as follows:
- By default, it should be enabled.
- Two environment variables can disable it:
- Three environment variables can disable it:
- VLLM_DO_NOT_TRACK=1
- DO_NOT_TRACK=1
- VLLM_NO_USAGE_STATS=1
- A file in the home directory can disable it if it exists:
...
...
@@ -36,8 +38,8 @@ def is_usage_stats_enabled():
"""
global
_USAGE_STATS_ENABLED
if
_USAGE_STATS_ENABLED
is
None
:
do_not_track
=
os
.
environ
.
get
(
"DO_NOT_TRACK"
,
"0"
)
==
"1"
no_usage_stats
=
os
.
environ
.
get
(
"
VLLM_NO_USAGE_STATS
"
,
"0"
)
==
"1"
do_not_track
=
envs
.
VLLM_DO_NOT_TRACK
no_usage_stats
=
envs
.
VLLM_NO_USAGE_STATS
do_not_track_file
=
os
.
path
.
exists
(
_USAGE_STATS_DO_NOT_TRACK_PATH
)
_USAGE_STATS_ENABLED
=
not
(
do_not_track
or
no_usage_stats
...
...
@@ -167,7 +169,7 @@ class UsageMessage:
# Metadata
self
.
log_time
=
_get_current_timestamp_ns
()
self
.
source
=
os
.
environ
.
get
(
"
VLLM_USAGE_SOURCE
"
,
"production"
)
self
.
source
=
envs
.
VLLM_USAGE_SOURCE
data
=
vars
(
self
)
if
extra_kvs
:
...
...
vllm/utils.py
View file @
1591c68f
import
asyncio
import
datetime
import
enum
import
gc
import
glob
import
os
import
socket
import
subprocess
import
tempfile
import
threading
import
uuid
import
warnings
from
collections
import
defaultdict
...
...
@@ -18,7 +21,8 @@ import psutil
import
torch
from
packaging.version
import
Version
,
parse
from
vllm.logger
import
init_logger
import
vllm.envs
as
envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
T
=
TypeVar
(
"T"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -171,7 +175,7 @@ def get_vllm_instance_id():
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return
os
.
environ
.
get
(
"
VLLM_INSTANCE_ID
"
,
f
"vllm-instance-
{
random_uuid
()
}
"
)
return
envs
.
VLLM_INSTANCE_ID
or
f
"vllm-instance-
{
random_uuid
()
}
"
@
lru_cache
(
maxsize
=
None
)
...
...
@@ -222,18 +226,25 @@ def merge_async_iterators(
]
async
def
consumer
():
while
not
all
(
finished
)
or
not
queue
.
empty
():
item
=
await
queue
.
get
()
if
isinstance
(
item
,
Exception
):
raise
item
yield
item
try
:
while
not
all
(
finished
)
or
not
queue
.
empty
():
item
=
await
queue
.
get
()
if
isinstance
(
item
,
Exception
):
raise
item
yield
item
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
for
task
in
_tasks
:
# NOTE: Pass the error msg in cancel()
# when only Python 3.9+ is supported.
task
.
cancel
()
raise
e
await
asyncio
.
gather
(
*
_tasks
)
return
consumer
()
def
get_ip
()
->
str
:
host_ip
=
os
.
environ
.
get
(
"
HOST_IP
"
)
host_ip
=
envs
.
VLLM_
HOST_IP
if
host_ip
:
return
host_ip
...
...
@@ -259,7 +270,8 @@ def get_ip() -> str:
warnings
.
warn
(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable HOST_IP."
,
"The value can be set by the environment variable"
" VLLM_HOST_IP or HOST_IP."
,
stacklevel
=
2
)
return
"0.0.0.0"
...
...
@@ -286,8 +298,9 @@ def get_open_port() -> int:
def
update_environment_variables
(
envs
:
Dict
[
str
,
str
]):
for
k
,
v
in
envs
.
items
():
if
k
in
os
.
environ
and
os
.
environ
[
k
]
!=
v
:
logger
.
warning
(
f
"Overwriting environment variable
{
k
}
"
f
"from '
{
os
.
environ
[
k
]
}
' to '
{
v
}
'"
)
logger
.
warning
(
"Overwriting environment variable %s "
"from '%s' to '%s'"
,
k
,
os
.
environ
[
k
],
v
)
os
.
environ
[
k
]
=
v
...
...
@@ -303,15 +316,16 @@ def cdiv(a: int, b: int) -> int:
@
lru_cache
(
maxsize
=
None
)
def
get_nvcc_cuda_version
()
->
Optional
[
Version
]:
cuda_home
=
os
.
environ
.
get
(
'
CUDA_HOME
'
)
cuda_home
=
envs
.
CUDA_HOME
if
not
cuda_home
:
cuda_home
=
'/usr/local/cuda'
if
os
.
path
.
isfile
(
cuda_home
+
'/bin/nvcc'
):
logger
.
info
(
f
'CUDA_HOME is not found in the environment. '
f
'Using
{
cuda_home
}
as CUDA_HOME.'
)
logger
.
info
(
'CUDA_HOME is not found in the environment. '
'Using %s as CUDA_HOME.'
,
cuda_home
)
else
:
logger
.
warning
(
f
'Not found nvcc in
{
cuda_home
}
. Skip cuda version check!'
)
logger
.
warning
(
'Not found nvcc in %s. Skip cuda version check!'
,
cuda_home
)
return
None
nvcc_output
=
subprocess
.
check_output
([
cuda_home
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
...
...
@@ -341,21 +355,9 @@ def _generate_random_fp8(
del
tensor_tmp
def
create_kv_caches_with_random
(
num_blocks
:
int
,
block_size
:
int
,
num_layers
:
int
,
num_heads
:
int
,
head_size
:
int
,
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
def
get_kv_cache_torch_dtype
(
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
)
->
torch
.
dtype
:
if
isinstance
(
cache_dtype
,
str
):
if
cache_dtype
==
"auto"
:
if
isinstance
(
model_dtype
,
str
):
...
...
@@ -374,6 +376,55 @@ def create_kv_caches_with_random(
torch_dtype
=
cache_dtype
else
:
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
return
torch_dtype
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create randomly filled key and value caches, one pair per layer.

    For every layer a single tensor of shape
    ``(num_blocks, 2, block_size, num_heads, head_size)`` is allocated and
    filled in place with uniform values between ``-head_size ** -0.5`` and
    ``head_size ** -0.5``. Index 0 along the second dimension is returned
    as that layer's key cache and index 1 as its value cache, so both are
    views into the same underlying tensor.

    Args:
        num_blocks: Size of the first dimension of each cache tensor.
        block_size: Size of the third dimension of each cache tensor.
        num_layers: Number of key/value cache pairs to create.
        num_heads: Size of the fourth dimension of each cache tensor.
        head_size: Size of the last dimension; also sets the value range.
        cache_dtype: Cache dtype, resolved through
            ``get_kv_cache_torch_dtype``. Must not be ``"fp8"``.
        model_dtype: Forwarded to ``get_kv_cache_torch_dtype``.
        seed: Seed applied to the torch RNG (and the CUDA RNG when CUDA
            is available) before any tensor is filled.
        device: Device on which the tensors are allocated.

    Returns:
        A ``(key_caches, value_caches)`` tuple of lists, each of length
        ``num_layers``.
    """
    assert cache_dtype != "fp8"

    # Seed first so the generated contents are reproducible for a seed.
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    fused_shape = (num_blocks, 2, block_size, num_heads, head_size)
    bound = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []
    for _layer in range(num_layers):
        # Keys and values share one allocation; uniform_ fills it in place
        # and returns the same tensor.
        fused = torch.empty(size=fused_shape, dtype=dtype,
                            device=device).uniform_(-bound, bound)
        layer_keys, layer_values = fused[:, 0], fused[:, 1]
        key_caches.append(layer_keys)
        value_caches.append(layer_values)
    return key_caches, value_caches
def
create_kv_caches_with_random
(
num_blocks
:
int
,
block_size
:
int
,
num_layers
:
int
,
num_heads
:
int
,
head_size
:
int
,
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
scale
=
head_size
**-
0.5
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
...
...
@@ -569,7 +620,7 @@ def find_library(lib_name: str) -> str:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
locs
=
[
line
.
split
()[
-
1
]
for
line
in
libs
.
splitlines
()
if
lib_name
in
line
]
# `LD_LIBRARY_PATH` searches the library in the user-defined paths
env_ld_library_path
=
os
.
get
env
(
"
LD_LIBRARY_PATH
"
)
env_ld_library_path
=
env
s
.
LD_LIBRARY_PATH
if
not
locs
and
env_ld_library_path
:
locs
=
[
os
.
path
.
join
(
dir
,
lib_name
)
...
...
@@ -582,22 +633,23 @@ def find_library(lib_name: str) -> str:
def
find_nccl_library
():
so_file
=
os
.
environ
.
get
(
"VLLM_NCCL_SO_PATH"
,
""
)
so_file
=
envs
.
VLLM_NCCL_SO_PATH
VLLM_CONFIG_ROOT
=
envs
.
VLLM_CONFIG_ROOT
# check if we have vllm-managed nccl
vllm_nccl_path
=
None
if
torch
.
version
.
cuda
is
not
None
:
cuda_major
=
torch
.
version
.
cuda
.
split
(
"."
)[
0
]
path
=
os
.
path
.
expanduser
(
f
"
~/.config
/vllm/nccl/cu
{
cuda_major
}
/libnccl.so.*"
)
f
"
{
VLLM_CONFIG_ROOT
}
/vllm/nccl/cu
{
cuda_major
}
/libnccl.so.*"
)
files
=
glob
.
glob
(
path
)
vllm_nccl_path
=
files
[
0
]
if
files
else
None
# manually load the nccl library
if
so_file
:
logger
.
info
(
f
"Found nccl from environment variable VLLM_NCCL_SO_PATH=
{
so_file
}
"
)
"Found nccl from environment variable VLLM_NCCL_SO_PATH=
%s"
,
so_file
)
else
:
if
torch
.
version
.
cuda
is
not
None
:
so_file
=
vllm_nccl_path
or
find_library
(
"libnccl.so.2"
)
...
...
@@ -605,5 +657,21 @@ def find_nccl_library():
so_file
=
find_library
(
"librccl.so.1"
)
else
:
raise
ValueError
(
"NCCL only supports CUDA and ROCm backends."
)
logger
.
info
(
f
"Found nccl from library
{
so_file
}
"
)
logger
.
info
(
"Found nccl from library
%s"
,
so_file
)
return
so_file
def enable_trace_function_call_for_thread() -> None:
    """Enable function-call tracing for the calling thread when requested.

    Does nothing unless ``envs.VLLM_TRACE_FUNCTION`` is truthy. Otherwise a
    log file path is built under
    ``<system temp dir>/vllm/<vllm instance id>/``; the file name embeds
    the current process id, thread identifier and timestamp (spaces
    replaced by underscores). The parent directory is created if missing
    and the path is handed to ``enable_trace_function_call``.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    tmp_dir = tempfile.gettempdir()
    pid = os.getpid()
    thread_id = threading.get_ident()
    started_at = datetime.datetime.now()
    # The default datetime string contains a space; normalise it so the
    # file name has no whitespace.
    raw_name = (f"VLLM_TRACE_FUNCTION_for_process_{pid}"
                f"_thread_{thread_id}_"
                f"at_{started_at}.log")
    filename = raw_name.replace(" ", "_")

    log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
                            filename)
    parent_dir = os.path.dirname(log_path)
    os.makedirs(parent_dir, exist_ok=True)
    enable_trace_function_call(log_path)
vllm/worker/cpu_model_runner.py
View file @
1591c68f
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
,
maybe_expand_dim
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
logger
=
init_logger
(
__name__
)
...
...
@@ -38,6 +37,8 @@ class CPUModelRunner:
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert
self
.
scheduler_config
.
chunked_prefill_enabled
is
False
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
load_config
=
load_config
...
...
@@ -79,7 +80,7 @@ class CPUModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
prompt
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
...
...
@@ -91,15 +92,15 @@ class CPUModelRunner:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
prompt
_len
=
len
(
prompt_tokens
)
seq
_len
=
len
(
prompt_tokens
)
prompt
_lens
.
append
(
prompt
_len
)
# Prompt token num
seq
_lens
.
append
(
seq
_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prompt
_len
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq
_len
)))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
...
...
@@ -108,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0,
prompt
_len - sliding_window).
# where start_idx is max(0,
seq
_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
prompt
_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
seq
_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
prompt
_len
):
for
i
in
range
(
computed_len
,
seq
_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
...
...
@@ -150,19 +151,19 @@ class CPUModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
prompt_lens
=
prompt_lens
,
num_prefills
=
len
(
prompt_lens
),
seq_lens
=
seq_lens
,
seq_lens_tensor
=
None
,
max_seq_len
=
None
,
num_prefills
=
len
(
seq_lens
),
num_prefill_tokens
=
num_prompt_tokens
,
num_decode_tokens
=
0
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
max_context_len
=
None
,
context_lens
=
None
,
block_tables
=
torch
.
tensor
([]),
slot_mapping
=
slot_mapping
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt
_lens
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq
_lens
,
multi_modal_input
)
def
_prepare_decode
(
...
...
@@ -173,7 +174,7 @@ class CPUModelRunner:
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
context
_lens
:
List
[
int
]
=
[]
seq
_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
...
...
@@ -191,9 +192,9 @@ class CPUModelRunner:
position
=
seq_len
-
1
input_positions
.
append
(
position
)
context
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq
_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
context
_lens
.
append
(
context
_len
)
seq
_lens
.
append
(
seq
_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
...
...
@@ -207,7 +208,7 @@ class CPUModelRunner:
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
max_
context
_len
=
max
(
context
_lens
)
max_
seq
_len
=
max
(
seq
_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
...
...
@@ -218,9 +219,9 @@ class CPUModelRunner:
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq
_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
...
...
@@ -235,14 +236,14 @@ class CPUModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_seq_len
=
max_seq_len
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
len
(
input_tokens
),
max_context_len
=
max_context_len
,
num_prefills
=
0
,
prefill_metadata
=
None
,
decode_metadata
=
None
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
...
...
@@ -252,99 +253,6 @@ class CPUModelRunner:
attn_metadata
,
)
def _prepare_sample(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
) -> SamplingMetadata:
    """Build the SamplingMetadata for one batch of sequence groups.

    Args:
        seq_group_metadata_list: Sequence groups in the batch, in order.
        prompt_lens: Indexed by position in ``seq_group_metadata_list``;
            only read for groups whose ``is_prompt`` is true.

    Returns:
        A SamplingMetadata holding the per-group (seq_ids, sampling_params)
        pairs, the merged ``seq_data`` of all groups, ``prompt_lens``, the
        selected token indices as a long tensor, the per-sampling-type
        index tensors, and the generators of seeded groups.
    """
    seq_groups: List[Tuple[List[int], SamplingParams]] = []
    selected_token_indices: List[int] = []
    generators: List[torch.Generator] = []
    # Running offset into the batch's token positions.
    selected_token_start_idx = 0
    # One list of (sample index, sampled-token index) pairs per sampling
    # type.
    categorized_sample_indices: Dict[SamplingType,
                                     List[Tuple[int, int]]] = {
                                         t: []
                                         for t in SamplingType
                                     }
    categorized_sample_indices_start_idx = 0
    categorized_sampled_token_indices_start_idx = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
        seq_groups.append((seq_ids, sampling_params))

        if seq_group_metadata.is_prompt:
            # A prompt group is expected to contain exactly one sequence.
            assert len(seq_ids) == 1
            subquery_len = prompt_lens[i]
            if sampling_params.prompt_logprobs is not None:
                # NOTE: prompt token positions do not need sample, skip
                categorized_sample_indices_start_idx += subquery_len - 1

            categorized_sample_indices[
                sampling_params.sampling_type].append(
                    (categorized_sample_indices_start_idx,
                     categorized_sampled_token_indices_start_idx))
            categorized_sample_indices_start_idx += 1
            categorized_sampled_token_indices_start_idx += 1

            # With prompt_logprobs every prompt position is selected;
            # otherwise only the last position of the prompt is.
            if sampling_params.prompt_logprobs is not None:
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + subquery_len - 1))
            selected_token_indices.append(selected_token_start_idx +
                                          subquery_len - 1)
            selected_token_start_idx += subquery_len

            # Seeded requests get their own generator, stored on the
            # group's state.
            if sampling_params.seed is not None:
                seq_group_metadata.state.generator = torch.Generator(
                    device=self.device).manual_seed(sampling_params.seed)
        else:
            # Non-prompt group: one selected position per sequence.
            num_seqs = len(seq_ids)
            selected_token_indices.extend(
                range(selected_token_start_idx,
                      selected_token_start_idx + num_seqs))
            selected_token_start_idx += num_seqs

            categorized_sample_indices[
                sampling_params.sampling_type].extend(
                    zip(
                        range(
                            categorized_sample_indices_start_idx,
                            categorized_sample_indices_start_idx +
                            num_seqs),
                        range(
                            categorized_sampled_token_indices_start_idx,
                            categorized_sampled_token_indices_start_idx +
                            num_seqs)))
            categorized_sample_indices_start_idx += num_seqs
            categorized_sampled_token_indices_start_idx += num_seqs

        # Collect the generator of every seeded group.
        if sampling_params.seed is not None:
            generators.append(seq_group_metadata.state.generator)

    selected_token_indices = torch.tensor(selected_token_indices,
                                          dtype=torch.long)

    # Convert each per-type list of index pairs to an int tensor.
    # NOTE(review): maybe_expand_dim presumably guarantees a 2-D result
    # even for an empty list -- confirm against its definition.
    categorized_sample_indices = {
        t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
        for t, seq_ids in categorized_sample_indices.items()
    }

    # Merge the seq_data mappings of all groups into one dict.
    seq_data: Dict[int, SequenceData] = {}
    for seq_group_metadata in seq_group_metadata_list:
        seq_data.update(seq_group_metadata.seq_data)

    sampling_metadata = SamplingMetadata(
        seq_groups=seq_groups,
        seq_data=seq_data,
        prompt_lens=prompt_lens,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        generators=generators,
    )
    return sampling_metadata
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -357,15 +265,22 @@ class CPUModelRunner:
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt
_lens
,
(
input_tokens
,
input_positions
,
attn_metadata
,
seq
_lens
,
multi_modal_input
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
)
seq_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens
,
self
.
device
,
pin_memory
=
False
)
# Broadcast the metadata.
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
...
...
@@ -385,11 +300,10 @@ class CPUModelRunner:
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt
_lens
=
None
,
seq
_lens
=
None
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
generators
=
None
,
perform_sampling
=
False
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
...
...
@@ -421,7 +335,7 @@ class CPUModelRunner:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
s
ampling_metadata
.
perform_sampling
:
if
not
s
elf
.
is_driver_worker
:
return
None
# Sample the next token.
...
...
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment