Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99b471c2
Commit
99b471c2
authored
May 21, 2024
by
zhuwenwen
Browse files
merge v0.4.1
parents
1925d2e9
468d761b
Changes
336
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
392 additions
and
215 deletions
+392
-215
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+9
-18
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+5
-12
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+5
-12
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+5
-12
vllm/model_executor/parallel_utils/README.md
vllm/model_executor/parallel_utils/README.md
+0
-1
vllm/model_executor/parallel_utils/utils.py
vllm/model_executor/parallel_utils/utils.py
+0
-48
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+4
-0
vllm/outputs.py
vllm/outputs.py
+3
-1
vllm/sampling_params.py
vllm/sampling_params.py
+45
-2
vllm/sequence.py
vllm/sequence.py
+75
-16
vllm/spec_decode/__init__.py
vllm/spec_decode/__init__.py
+0
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+49
-22
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+4
-4
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+24
-8
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+23
-17
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+104
-31
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+26
-0
vllm/test_utils.py
vllm/test_utils.py
+6
-7
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+4
-3
vllm/transformers_utils/configs/dbrx.py
vllm/transformers_utils/configs/dbrx.py
+1
-1
No files found.
vllm/model_executor/models/qwen2_moe.py
View file @
99b471c2
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -30,6 +30,9 @@ from torch import nn
...
@@ -30,6 +30,9 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -43,13 +46,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -43,13 +46,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -367,6 +365,8 @@ class Qwen2MoeModel(nn.Module):
...
@@ -367,6 +365,8 @@ class Qwen2MoeModel(nn.Module):
class
Qwen2MoeForCausalLM
(
nn
.
Module
):
class
Qwen2MoeForCausalLM
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
...
@@ -405,11 +405,7 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -405,11 +405,7 @@ class Qwen2MoeForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -420,12 +416,7 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -420,12 +416,7 @@ class Qwen2MoeForCausalLM(nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
,
fall_back_to_pt
=
False
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
...
...
vllm/model_executor/models/stablelm.py
View file @
99b471c2
...
@@ -19,13 +19,14 @@
...
@@ -19,13 +19,14 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -36,11 +37,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -36,11 +37,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -263,11 +261,7 @@ class StablelmForCausalLM(nn.Module):
...
@@ -263,11 +261,7 @@ class StablelmForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -277,8 +271,7 @@ class StablelmForCausalLM(nn.Module):
...
@@ -277,8 +271,7 @@ class StablelmForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
(
"rotary_emb.cos_cached"
in
name
if
(
"rotary_emb.cos_cached"
in
name
...
...
vllm/model_executor/models/starcoder2.py
View file @
99b471c2
...
@@ -18,13 +18,14 @@
...
@@ -18,13 +18,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" PyTorch Starcoder2 model."""
""" PyTorch Starcoder2 model."""
from
typing
import
List
,
Optional
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Starcoder2Config
from
transformers
import
Starcoder2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
...
@@ -35,11 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -35,11 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -275,11 +273,7 @@ class Starcoder2ForCausalLM(nn.Module):
...
@@ -275,11 +273,7 @@ class Starcoder2ForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -288,8 +282,7 @@ class Starcoder2ForCausalLM(nn.Module):
...
@@ -288,8 +282,7 @@ class Starcoder2ForCausalLM(nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
...
vllm/model_executor/models/xverse.py
View file @
99b471c2
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
"""Inference-only Xverse model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -28,6 +28,7 @@ from transformers import PretrainedConfig
...
@@ -28,6 +28,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
...
@@ -39,11 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -39,11 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -332,11 +330,7 @@ class XverseForCausalLM(nn.Module):
...
@@ -332,11 +330,7 @@ class XverseForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
...
@@ -345,8 +339,7 @@ class XverseForCausalLM(nn.Module):
...
@@ -345,8 +339,7 @@ class XverseForCausalLM(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
(
"rotary_emb.inv_freq"
in
name
if
(
"rotary_emb.inv_freq"
in
name
or
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
or
"rotary_emb.sin_cached"
in
name
):
...
...
vllm/model_executor/parallel_utils/README.md
deleted
100644 → 0
View file @
1925d2e9
The files in this folder are ported from
[
Megatron-LM
](
https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core
)
. We only keep the codes that are used in inference.
\ No newline at end of file
vllm/model_executor/parallel_utils/utils.py
deleted
100644 → 0
View file @
1925d2e9
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from
typing
import
Sequence
import
torch
def
ensure_divisibility
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator."""
assert
numerator
%
denominator
==
0
,
"{} is not divisible by {}"
.
format
(
numerator
,
denominator
)
def
divide
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility
(
numerator
,
denominator
)
return
numerator
//
denominator
def
split_tensor_along_last_dim
(
tensor
:
torch
.
Tensor
,
num_partitions
:
int
,
contiguous_split_chunks
:
bool
=
False
,
)
->
Sequence
[
torch
.
Tensor
]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim
=
tensor
.
dim
()
-
1
last_dim_size
=
divide
(
tensor
.
size
()[
last_dim
],
num_partitions
)
# Split.
tensor_list
=
torch
.
split
(
tensor
,
last_dim_size
,
dim
=
last_dim
)
# NOTE: torch.split does not create contiguous tensors by default.
if
contiguous_split_chunks
:
return
tuple
(
chunk
.
contiguous
()
for
chunk
in
tensor_list
)
return
tensor_list
vllm/model_executor/sampling_metadata.py
View file @
99b471c2
...
@@ -113,6 +113,8 @@ class SamplingTensors:
...
@@ -113,6 +113,8 @@ class SamplingTensors:
get_num_triton_sampler_splits
(
vocab_size
))
get_num_triton_sampler_splits
(
vocab_size
))
sample_indices_start_idx
=
0
sample_indices_start_idx
=
0
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_data
is
not
None
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
seq_ids
,
sampling_params
=
seq_group
temperature
=
sampling_params
.
temperature
temperature
=
sampling_params
.
temperature
...
@@ -147,6 +149,7 @@ class SamplingTensors:
...
@@ -147,6 +149,7 @@ class SamplingTensors:
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
assert
sampling_metadata
.
prompt_lens
is
not
None
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
...
@@ -172,6 +175,7 @@ class SamplingTensors:
...
@@ -172,6 +175,7 @@ class SamplingTensors:
is_prompt
=
i
<
sampling_metadata
.
num_prompts
is_prompt
=
i
<
sampling_metadata
.
num_prompts
if
is_prompt
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
prompt_best_of
.
append
(
sampling_params
.
best_of
)
assert
sampling_metadata
.
prompt_lens
is
not
None
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
if
sampling_params
.
prompt_logprobs
is
not
None
:
...
...
vllm/outputs.py
View file @
99b471c2
...
@@ -112,8 +112,10 @@ class RequestOutput:
...
@@ -112,8 +112,10 @@ class RequestOutput:
# always has the logprobs of the sampled tokens even if the
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
# logprobs are not requested.
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
output_text
,
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_output_token_ids
(),
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
(),
seq
.
get_cumulative_logprob
(),
seq
.
output_logprobs
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
...
...
vllm/sampling_params.py
View file @
99b471c2
...
@@ -2,9 +2,11 @@
...
@@ -2,9 +2,11 @@
import
copy
import
copy
from
enum
import
IntEnum
from
enum
import
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
pydantic
import
Field
from
typing_extensions
import
Annotated
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -88,11 +90,15 @@ class SamplingParams:
...
@@ -88,11 +90,15 @@ class SamplingParams:
log probability of the sampled token, so there may be up to
log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response.
`logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
logits_processors: List of functions that modify logits based on
previously generated tokens.
previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
"""
"""
def
__init__
(
def
__init__
(
...
@@ -118,9 +124,11 @@ class SamplingParams:
...
@@ -118,9 +124,11 @@ class SamplingParams:
min_tokens
:
int
=
0
,
min_tokens
:
int
=
0
,
logprobs
:
Optional
[
int
]
=
None
,
logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
detokenize
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
)
->
None
:
)
->
None
:
self
.
n
=
n
self
.
n
=
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
...
@@ -150,10 +158,22 @@ class SamplingParams:
...
@@ -150,10 +158,22 @@ class SamplingParams:
self
.
min_tokens
=
min_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
logprobs
self
.
logprobs
=
logprobs
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self
.
detokenize
=
detokenize
self
.
skip_special_tokens
=
skip_special_tokens
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
logits_processors
=
logits_processors
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
truncate_prompt_tokens
=
truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if
self
.
stop
and
not
include_stop_str_in_output
:
self
.
output_text_buffer_length
=
max
(
len
(
s
)
for
s
in
self
.
stop
)
-
1
else
:
self
.
output_text_buffer_length
=
0
self
.
_verify_args
()
self
.
_verify_args
()
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
self
.
_verify_beam_search
()
self
.
_verify_beam_search
()
...
@@ -210,6 +230,16 @@ class SamplingParams:
...
@@ -210,6 +230,16 @@ class SamplingParams:
if
self
.
prompt_logprobs
is
not
None
and
self
.
prompt_logprobs
<
0
:
if
self
.
prompt_logprobs
is
not
None
and
self
.
prompt_logprobs
<
0
:
raise
ValueError
(
f
"prompt_logprobs must be non-negative, got "
raise
ValueError
(
f
"prompt_logprobs must be non-negative, got "
f
"
{
self
.
prompt_logprobs
}
."
)
f
"
{
self
.
prompt_logprobs
}
."
)
if
(
self
.
truncate_prompt_tokens
is
not
None
and
self
.
truncate_prompt_tokens
<
1
):
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
if
any
(
not
stop_str
for
stop_str
in
self
.
stop
):
raise
ValueError
(
"stop cannot contain an empty string."
)
if
self
.
stop
and
not
self
.
detokenize
:
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
def
_verify_beam_search
(
self
)
->
None
:
def
_verify_beam_search
(
self
)
->
None
:
if
self
.
best_of
==
1
:
if
self
.
best_of
==
1
:
...
@@ -241,6 +271,18 @@ class SamplingParams:
...
@@ -241,6 +271,18 @@ class SamplingParams:
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
f
"Got
{
self
.
best_of
}
."
)
f
"Got
{
self
.
best_of
}
."
)
def
update_from_generation_config
(
self
,
generation_config
:
Dict
[
str
,
Any
])
->
None
:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
):
# it can be either int or list of int
if
isinstance
(
eos_ids
,
int
):
eos_ids
=
[
eos_ids
]
original_stop_token_ids
=
set
(
self
.
stop_token_ids
)
original_stop_token_ids
.
update
(
eos_ids
)
self
.
stop_token_ids
=
list
(
original_stop_token_ids
)
@
cached_property
@
cached_property
def
sampling_type
(
self
)
->
SamplingType
:
def
sampling_type
(
self
)
->
SamplingType
:
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
...
@@ -290,4 +332,5 @@ class SamplingParams:
...
@@ -290,4 +332,5 @@ class SamplingParams:
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
, "
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
, "
f
"skip_special_tokens=
{
self
.
skip_special_tokens
}
, "
f
"skip_special_tokens=
{
self
.
skip_special_tokens
}
, "
"spaces_between_special_tokens="
"spaces_between_special_tokens="
f
"
{
self
.
spaces_between_special_tokens
}
)"
)
f
"
{
self
.
spaces_between_special_tokens
}
, "
f
"truncate_prompt_tokens=
{
self
.
truncate_prompt_tokens
}
)"
)
vllm/sequence.py
View file @
99b471c2
...
@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
...
@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
return
finish_reason
return
finish_reason
class
SequenceStage
(
enum
.
Enum
):
PREFILL
=
enum
.
auto
()
DECODE
=
enum
.
auto
()
@
dataclass
@
dataclass
class
RequestMetrics
:
class
RequestMetrics
:
"""Metrics associated with a request.
"""Metrics associated with a request.
...
@@ -115,6 +120,7 @@ class SequenceData:
...
@@ -115,6 +120,7 @@ class SequenceData:
self
.
cumulative_logprob
=
0.0
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
# The number of tokens that are computed (that run against the model).
self
.
_num_computed_tokens
=
0
self
.
_num_computed_tokens
=
0
self
.
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
output_token_ids
.
append
(
token_id
)
self
.
output_token_ids
.
append
(
token_id
)
...
@@ -136,19 +142,25 @@ class SequenceData:
...
@@ -136,19 +142,25 @@ class SequenceData:
"""Return the number of prefill tokens that are already computed."""
"""Return the number of prefill tokens that are already computed."""
return
self
.
_num_computed_tokens
return
self
.
_num_computed_tokens
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
)
->
int
:
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
self
.
_num_computed_tokens
+=
num_new_computed_tokens
self
.
_num_computed_tokens
+=
num_new_computed_tokens
assert
self
.
_num_computed_tokens
<=
self
.
get_len
(),
(
self
.
_num_computed_tokens
,
self
.
get_len
())
# If all tokens are computed, it means it is in decoding phase.
if
self
.
get_num_uncomputed_tokens
()
==
0
:
self
.
_stage
=
SequenceStage
.
DECODE
def
reset_
num_computed_tokens
(
self
)
->
None
:
def
reset_
state_for_recompute
(
self
)
->
None
:
"""Reset the number of computed tokens from this sequence. It is
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted).
the beginning again (e.g., sequence is preempted).
"""
"""
self
.
_num_computed_tokens
=
0
self
.
_num_computed_tokens
=
0
self
.
_stage
=
SequenceStage
.
PREFILL
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
"""Return the number of prefil tokens that are not computed."""
"""Return the number of prefil
l
tokens that are not computed."""
# we use `get_len()` which includes prompt_len + output_len instead
# we use `get_len()` which includes prompt_len + output_len instead
# of prompt_len here. This is because during recompute we need to
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
# prefill for both prompt and output.
...
@@ -159,12 +171,16 @@ class SequenceData:
...
@@ -159,12 +171,16 @@ class SequenceData:
return
self
.
prompt_token_ids
[
-
1
]
return
self
.
prompt_token_ids
[
-
1
]
return
self
.
output_token_ids
[
-
1
]
return
self
.
output_token_ids
[
-
1
]
def
get_prompt_token_ids
(
self
)
->
int
:
def
get_prompt_token_ids
(
self
)
->
List
[
int
]
:
return
self
.
prompt_token_ids
return
self
.
prompt_token_ids
def
get_output_token_ids
(
self
)
->
int
:
def
get_output_token_ids
(
self
)
->
List
[
int
]
:
return
self
.
output_token_ids
return
self
.
output_token_ids
@
property
def
stage
(
self
)
->
SequenceStage
:
return
self
.
_stage
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceData("
return
(
f
"SequenceData("
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
...
@@ -219,6 +235,12 @@ class Sequence:
...
@@ -219,6 +235,12 @@ class Sequence:
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
def
get_output_text_to_return
(
self
,
buffer_length
:
int
):
# We return the full output text if the sequence is finished.
truncate
=
buffer_length
and
not
self
.
is_finished
()
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
self
.
output_text
)
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
# TODO This can produce incorrect hash when block size > prompt size
...
@@ -234,7 +256,7 @@ class Sequence:
...
@@ -234,7 +256,7 @@ class Sequence:
def
reset_state_for_recompute
(
self
):
def
reset_state_for_recompute
(
self
):
"""Reset the sequence states for recomputation."""
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_
num_computed_tokens
()
self
.
data
.
reset_
state_for_recompute
()
def
_append_logical_block
(
self
)
->
None
:
def
_append_logical_block
(
self
)
->
None
:
block
=
LogicalTokenBlock
(
block
=
LogicalTokenBlock
(
...
@@ -320,6 +342,20 @@ class Sequence:
...
@@ -320,6 +342,20 @@ class Sequence:
new_seq
.
seq_id
=
new_seq_id
new_seq
.
seq_id
=
new_seq_id
return
new_seq
return
new_seq
def
get_num_new_tokens
(
self
)
->
int
:
"""Get the number of new tokens to be computed.
Returns:
The new number of tokens to be computed. I.e., 1 for decode, or
the remaining prompt size for prefill.
"""
if
self
.
data
.
stage
==
SequenceStage
.
DECODE
:
return
1
return
self
.
data
.
get_num_uncomputed_tokens
()
def
is_prefill
(
self
)
->
bool
:
return
self
.
data
.
stage
==
SequenceStage
.
PREFILL
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"Sequence(seq_id=
{
self
.
seq_id
}
, "
return
(
f
"Sequence(seq_id=
{
self
.
seq_id
}
, "
f
"status=
{
self
.
status
.
name
}
, "
f
"status=
{
self
.
status
.
name
}
, "
...
@@ -331,7 +367,7 @@ class SequenceGroupState:
...
@@ -331,7 +367,7 @@ class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
# torch.Generator used in seeded sampling
generator
:
Optional
=
None
generator
:
Optional
=
None
# type: ignore
class
MultiModalData
:
class
MultiModalData
:
...
@@ -461,16 +497,22 @@ class SequenceGroup:
...
@@ -461,16 +497,22 @@ class SequenceGroup:
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
for
seq
in
self
.
seqs_dict
.
values
():
for
seq
in
self
.
seqs_dict
.
values
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
# All sequences in the group should have the same prompt, so the
num_uncomputed_tokens
=
0
# number of unfinished prefill tokens are the same across all
for
seq
in
self
.
get_seqs
():
# sequences.
if
not
seq
.
is_finished
():
return
list
(
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
self
.
seqs_dict
.
values
())[
0
].
data
.
get_
num_uncomputed_tokens
()
return
num_uncomputed_tokens
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
# Optimization. We don't need to call get_seqs if we don't need to
# filter by states.
if
status
is
None
:
return
len
(
self
.
seqs_dict
)
return
len
(
self
.
get_seqs
(
status
))
return
len
(
self
.
get_seqs
(
status
))
def
num_unfinished_seqs
(
self
)
->
int
:
def
num_unfinished_seqs
(
self
)
->
int
:
...
@@ -497,6 +539,10 @@ class SequenceGroup:
...
@@ -497,6 +539,10 @@ class SequenceGroup:
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
def
is_prefill
(
self
)
->
bool
:
# Every sequences should be in the same stage.
return
self
.
get_seqs
()[
0
].
is_prefill
()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
...
@@ -513,8 +559,8 @@ class SequenceGroupMetadata:
...
@@ -513,8 +559,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
block_tables: The block tables. (Seq id -> list of physical block
numbers)
numbers)
token_chunk_size: The number of tokens to be processed
. None if
token_chunk_size: The number of tokens to be processed
(per sequence).
chunking is not required.
None if
chunking is not required.
state: Internal state tied to this sequence group.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
lora_request: LoRA request.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
...
@@ -555,7 +601,7 @@ class SequenceGroupMetadata:
...
@@ -555,7 +601,7 @@ class SequenceGroupMetadata:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
@
property
@
property
def
token_chunk_size
(
self
)
->
int
:
def
token_chunk_size
(
self
)
->
Optional
[
int
]
:
"""Return the number of tokens to be processed (chunk size)."""
"""Return the number of tokens to be processed (chunk size)."""
return
self
.
_token_chunk_size
return
self
.
_token_chunk_size
...
@@ -649,3 +695,16 @@ class SamplerOutput:
...
@@ -649,3 +695,16 @@ class SamplerOutput:
def
__eq__
(
self
,
other
:
object
):
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
def
__repr__
(
self
)
->
str
:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr
=
(
"None"
if
self
.
sampled_token_probs
is
None
else
self
.
sampled_token_probs
.
shape
)
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
self
.
sampled_token_ids
.
shape
)
return
(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
vllm/spec_decode/__init__.py
0 → 100644
View file @
99b471c2
vllm/spec_decode/batch_expansion.py
View file @
99b471c2
...
@@ -9,7 +9,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
...
@@ -9,7 +9,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
sampler_output_to_torch
,
sampler_output_to_torch
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
Worker
Base
SeqId
=
int
SeqId
=
int
TargetSeqId
=
int
TargetSeqId
=
int
...
@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
of topk/tree.
"""
"""
def
__init__
(
self
,
scorer_worker
:
Worker
,
device
:
str
,
vocab_size
:
int
):
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
device
:
str
,
vocab_size
:
int
):
self
.
_scorer_worker
=
scorer_worker
self
.
_scorer_worker
=
scorer_worker
self
.
_device
=
device
self
.
_device
=
device
self
.
_vocab_size
=
vocab_size
self
.
_vocab_size
=
vocab_size
...
@@ -71,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -71,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
proposal_lens_list
=
proposals
.
proposal_lens
.
tolist
()
proposal_token_ids_list
=
proposals
.
proposal_token_ids
.
tolist
()
proposal_token_ids_list
=
proposals
.
proposal_token_ids
.
tolist
()
# Filter the list to ignore -1 proposals.
proposal_token_ids_list_without_skips
=
[
proposals
for
proposals
in
proposal_token_ids_list
if
-
1
not
in
proposals
]
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
=
self
.
_expand_batch
(
num_scoring_tokens
)
=
self
.
_expand_batch
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
proposal_token_ids_list
=
proposal_token_ids_list
,
proposal_token_ids_list
=
proposal_token_ids_list
_without_skips
,
proposal_lens_list
=
proposal_lens_list
,
proposal_lens_list
=
proposal_lens_list
,
)
)
...
@@ -83,10 +90,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -83,10 +90,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
return_python_output
=
False
)
)
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
all_tokens
,
all_probs
=
self
.
_contract_batch
(
original
_bs
=
len
(
seq_group_metadata_list
),
contracted
_bs
=
len
(
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
num_scoring_tokens
=
num_scoring_tokens
,
...
@@ -103,7 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -103,7 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_expand_batch
(
def
_expand_batch
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_token_ids_list
:
List
[
TokenId
],
proposal_token_ids_list
:
List
[
List
[
TokenId
]
]
,
proposal_lens_list
:
List
[
int
],
proposal_lens_list
:
List
[
int
],
)
->
Tuple
[
List
[
int
],
List
[
int
],
List
[
SequenceGroupMetadata
],
int
]:
)
->
Tuple
[
List
[
int
],
List
[
int
],
List
[
SequenceGroupMetadata
],
int
]:
"""Given the input sequences and potentially multiple corresponding
"""Given the input sequences and potentially multiple corresponding
...
@@ -125,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -125,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
select_proposal_len_zero
=
True
)
select_proposal_len_zero
=
True
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
spec_seqs
,
proposal_token_ids_list
)
seq_group_metadata_list
=
spec_seqs
,
proposal_token_ids
=
proposal_token_ids_list
,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter
=
self
.
_create_target_seq_id_iterator
(
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)),
)
num_scoring_tokens
=
len
(
target_seq_group_metadata_list
)
num_scoring_tokens
=
len
(
target_seq_group_metadata_list
)
target_seq_group_metadata_list
.
extend
(
non_spec_seqs
)
target_seq_group_metadata_list
.
extend
(
non_spec_seqs
)
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
return
(
spec_indices
,
non_spec_indices
,
target_seq_group_metadata_list
,
num_scoring_tokens
)
num_scoring_tokens
)
def
_contract_batch
(
self
,
original
_bs
:
int
,
def
_contract_batch
(
self
,
contracted
_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
],
target_sampler_output
:
List
[
SamplerOutput
],
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
...
@@ -141,6 +157,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -141,6 +157,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Contract the expanded batch back into its original size.
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
This maps the scores of speculative tokens back to their original
sequences.
sequences.
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
"""
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
...
@@ -148,25 +167,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -148,25 +167,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# Map distinct sequences used to score each token
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
batch_size
,
k
=
proposals
.
proposal_token_ids
.
shape
expanded_batch_size
,
k
=
proposals
.
proposal_token_ids
.
shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs
,
_
=
non_spec_target_token_ids
.
shape
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
squeeze
().
reshape
(
target_token_ids
=
target_token_ids
.
squeeze
().
reshape
(
batch_size
,
k
+
1
)
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
batch_size
,
k
+
1
,
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
self
.
_vocab_size
)
all_tokens
=
torch
.
full
(
size
=
(
original
_bs
,
k
+
1
),
all_tokens
=
torch
.
full
(
size
=
(
contracted
_bs
,
k
+
1
),
fill_value
=-
1
,
fill_value
=-
1
,
device
=
self
.
_device
,
device
=
self
.
_device
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
all_probs
=
torch
.
zeros
(
original
_bs
,
all_probs
=
torch
.
zeros
(
contracted
_bs
,
k
+
1
,
k
+
1
,
self
.
_vocab_size
,
self
.
_vocab_size
,
device
=
self
.
_device
,
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
if
non_spec_indices
:
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
0
]
=
non_spec_target_token_ids
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
if
spec_indices
:
if
spec_indices
:
...
@@ -176,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -176,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return
all_tokens
,
all_probs
return
all_tokens
,
all_probs
def
_create_scoring_model_input
(
def
_create_scoring_model_input
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_token_ids
:
List
[
List
[
TokenId
]],
# shape: [batch_size, k]
proposal_token_ids
:
List
[
List
[
TokenId
]],
# shape: [batch_size, k]
target_seq_ids_iter
:
Iterator
[
TargetSeqId
],
)
->
List
[
SequenceGroupMetadata
]:
)
->
List
[
SequenceGroupMetadata
]:
"""Given the original input sequences and proposed tokens from the draft
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.
model, create a list of target sequences that can be used for scoring.
target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to the seq id in the original batch.
"""
"""
if
not
seq_group_metadata_list
:
if
not
seq_group_metadata_list
:
return
[]
return
[]
target_seq_ids_iter
=
self
.
_create_target_seq_id_iterator
(
get_all_seq_ids
(
seq_group_metadata_list
))
target_seq_group_metadata
=
list
(
target_seq_group_metadata
=
list
(
chain
.
from_iterable
(
chain
.
from_iterable
(
self
.
_create_target_seq_group_metadata
(
self
.
_create_target_seq_group_metadata
(
...
@@ -205,7 +232,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -205,7 +232,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_create_target_seq_group_metadata
(
def
_create_target_seq_group_metadata
(
self
,
self
,
input_seq_group_metadata
:
SequenceGroupMetadata
,
input_seq_group_metadata
:
SequenceGroupMetadata
,
proposal_token_ids
:
List
[
TokenId
],
# shape: [batch_size, k]
proposal_token_ids
:
List
[
List
[
TokenId
]
]
,
# shape: [batch_size, k]
batch_index
:
int
,
batch_index
:
int
,
target_seq_ids_iter
:
Iterator
[
TargetSeqId
],
target_seq_ids_iter
:
Iterator
[
TargetSeqId
],
)
->
List
[
SequenceGroupMetadata
]:
)
->
List
[
SequenceGroupMetadata
]:
...
@@ -347,7 +374,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -347,7 +374,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
[0, 1, 2]
[0, 1, 2]
[0, 1, 2, 3]
[0, 1, 2, 3]
"""
"""
empty_token_ids
=
[]
empty_token_ids
:
List
[
TokenId
]
=
[]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
.
extend
([
token_ids_to_score
.
extend
([
...
...
vllm/spec_decode/interfaces.py
View file @
99b471c2
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
import
torch
import
torch
...
@@ -24,9 +24,9 @@ class SpeculativeProposals:
...
@@ -24,9 +24,9 @@ class SpeculativeProposals:
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
f
"SpeculativeProposals("
return
(
f
"SpeculativeProposals("
f
"proposal_token_ids=
{
self
.
proposal_token_ids
.
shape
}
, "
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
f
"proposal_probs=
{
self
.
proposal_probs
.
shape
}
, "
f
"proposal_probs=
{
self
.
proposal_probs
.
shape
}
, "
f
"proposal_lens=
{
self
.
proposal_lens
.
shape
}
)"
)
f
"proposal_lens=
{
self
.
proposal_lens
}
)"
)
@
dataclass
@
dataclass
...
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
...
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
k
:
int
,
k
:
int
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
SpeculativeScores
:
raise
NotImplementedError
raise
NotImplementedError
vllm/spec_decode/metrics.py
View file @
99b471c2
...
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
...
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
Returns a CUDA event recording when the copy is complete.
Returns a CUDA event recording when the copy is complete.
"""
"""
assert
self
.
_copy_stream
is
not
None
self
.
_copy_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
_copy_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
self
.
_copy_stream
):
with
torch
.
cuda
.
stream
(
self
.
_copy_stream
):
...
@@ -146,15 +147,16 @@ class AsyncMetricsCollector:
...
@@ -146,15 +147,16 @@ class AsyncMetricsCollector:
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
draft_tokens
=
self
.
_aggregate_num_draft_tokens
draft_tokens
=
self
.
_aggregate_num_draft_tokens
num_possible_tokens
=
self
.
get_max_num_accepted_tokens
(
draft_tokens
,
k
)
max_num_emitted_tokens
=
self
.
get_max_num_emitted_tokens
(
draft_tokens
,
k
)
if
draft_tokens
>
0
:
if
draft_tokens
>
0
:
draft_acceptance_rate
=
accepted_tokens
/
draft_tokens
draft_acceptance_rate
=
accepted_tokens
/
draft_tokens
else
:
else
:
draft_acceptance_rate
=
float
(
"nan"
)
draft_acceptance_rate
=
float
(
"nan"
)
if
num_possible
_tokens
>
0
:
if
max_num_emitted
_tokens
>
0
:
system_efficiency
=
emitted_tokens
/
num_possible
_tokens
system_efficiency
=
emitted_tokens
/
max_num_emitted
_tokens
else
:
else
:
system_efficiency
=
float
(
"nan"
)
system_efficiency
=
float
(
"nan"
)
...
@@ -168,8 +170,22 @@ class AsyncMetricsCollector:
...
@@ -168,8 +170,22 @@ class AsyncMetricsCollector:
)
)
@
staticmethod
@
staticmethod
def
get_max_num_accepted_tokens
(
draft_tokens
:
int
,
k
:
int
)
->
int
:
def
get_max_num_emitted_tokens
(
draft_tokens
:
int
,
k
:
int
)
->
int
:
# Divide by k since batch size can be variable.
"""Calculate the number of emitted tokens, assuming all tokens are
total_num_spec_seqs
=
draft_tokens
/
k
accepted.
num_accepted_per_seq_if_all_accepted
=
k
+
1
return
int
(
total_num_spec_seqs
/
num_accepted_per_seq_if_all_accepted
)
This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert
draft_tokens
%
k
==
0
total_num_spec_seqs
=
draft_tokens
//
k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted
=
k
+
1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return
total_num_spec_seqs
*
num_emitted_per_seq_if_all_accepted
vllm/spec_decode/multi_step_worker.py
View file @
99b471c2
...
@@ -25,7 +25,8 @@ class MultiStepWorker(Worker):
...
@@ -25,7 +25,8 @@ class MultiStepWorker(Worker):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_proposer
:
Optional
[
DraftModelTop1Proposer
]
=
None
# Lazy initialization list.
self
.
_proposer
:
DraftModelTop1Proposer
def
init_device
(
self
):
def
init_device
(
self
):
super
().
init_device
()
super
().
init_device
()
...
@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
...
@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
)
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_append_new_tokens
(
model_output
,
self
.
_append_new_tokens
(
model_output
,
copied_seq_group_metadata_list
)
copied_seq_group_metadata_list
)
...
@@ -324,23 +328,25 @@ class DraftModelTop1Proposer(SpeculativeProposer):
...
@@ -324,23 +328,25 @@ class DraftModelTop1Proposer(SpeculativeProposer):
"""
"""
if
maybe_sampler_output
is
None
:
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# If no speculative tokens, the sampler output will be None.
# In this case we return empty tensors.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
zeros
(
0
,
proposal_tokens
=
torch
.
full
(
size
=
(
max_proposal_len
,
batch_size
,
dtype
=
torch
.
long
,
max_proposal_len
,
device
=
self
.
_device
)
),
proposal_probs
=
torch
.
zeros
(
0
,
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_probs
=
torch
.
zeros
(
batch_size
,
max_proposal_len
,
max_proposal_len
,
self
.
_vocab_size
,
self
.
_vocab_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
)
device
=
self
.
_device
)
proposal_lens
=
torch
.
zeros
(
len
(
proposal_lens
),
proposal_lens
_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
device
=
self
.
_device
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens
return
proposal_tokens
,
proposal_probs
,
proposal_lens
_tensor
sampler_output
=
maybe_sampler_output
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
sampler_output
)
sampler_output
)
...
@@ -362,9 +368,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
...
@@ -362,9 +368,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
)
entire_proposal_probs
)
proposal_lens
=
torch
.
zeros
(
batch_size
,
proposal_lens
_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
device
=
self
.
_device
)
proposal_lens
[
nonzero_proposal_len_indices
]
=
max_proposal_len
proposal_lens
_tensor
[
nonzero_proposal_len_indices
]
=
max_proposal_len
return
proposal_tokens
,
proposal_probs
,
proposal_lens
return
proposal_tokens
,
proposal_probs
,
proposal_lens
_tensor
vllm/spec_decode/spec_decode_worker.py
View file @
99b471c2
...
@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Tuple
...
@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch
from
vllm.
config
import
CacheConfig
from
vllm.
logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
...
@@ -14,10 +14,12 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
...
@@ -14,10 +14,12 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
logger
=
init_logger
(
__name__
)
class
SpecDecodeWorker
:
class
SpecDecodeWorker
(
LoraNotSupportedWorkerBase
):
"""Worker which implements speculative decoding.
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
Speculative decoding reduces decoding per-token latency by using a proposal
...
@@ -45,10 +47,20 @@ class SpecDecodeWorker:
...
@@ -45,10 +47,20 @@ class SpecDecodeWorker:
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
"""
@
classmethod
def
from_workers
(
cls
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
WorkerBase
)
->
"SpecDecodeWorker"
:
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
# TODO(cade) disable strict mode for speedup.
rejection_sampler
=
RejectionSampler
(
strict_mode
=
True
),
)
def
__init__
(
def
__init__
(
self
,
self
,
proposer_worker
:
MultiStepWorker
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
Worker
,
scorer_worker
:
Worker
Base
,
rejection_sampler
:
RejectionSampler
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
):
):
...
@@ -77,7 +89,8 @@ class SpecDecodeWorker:
...
@@ -77,7 +89,8 @@ class SpecDecodeWorker:
self
.
probs_dtype
=
self
.
rejection_sampler
.
probs_dtype
self
.
probs_dtype
=
self
.
rejection_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
rejection_sampler
.
token_id_dtype
self
.
token_id_dtype
=
self
.
rejection_sampler
.
token_id_dtype
self
.
scorer
:
SpeculativeScorer
=
None
# Lazy initiazliation.
self
.
scorer
:
SpeculativeScorer
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""Initialize both scorer and proposer models.
...
@@ -87,6 +100,10 @@ class SpecDecodeWorker:
...
@@ -87,6 +100,10 @@ class SpecDecodeWorker:
self
.
scorer_worker
.
init_device
()
self
.
scorer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self
.
scorer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
rejection_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
rejection_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
scorer
=
BatchExpansionTop1Scorer
(
self
.
scorer
=
BatchExpansionTop1Scorer
(
...
@@ -94,10 +111,33 @@ class SpecDecodeWorker:
...
@@ -94,10 +111,33 @@ class SpecDecodeWorker:
device
=
self
.
device
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
vocab_size
=
self
.
_vocab_size
)
def
profile_num_available_blocks
(
self
,
block_size
:
int
,
self
.
_configure_model_sampler_for_spec_decode
()
gpu_memory_utilization
:
float
,
cpu_swap_space
:
int
,
def
_configure_model_sampler_for_spec_decode
(
self
):
cache_dtype
:
str
)
->
Tuple
[
int
,
int
]:
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
(
self
.
proposer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
"""Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the
This is done by profiling the scorer model (which is typically the
...
@@ -106,27 +146,26 @@ class SpecDecodeWorker:
...
@@ -106,27 +146,26 @@ class SpecDecodeWorker:
such that the number of blocks is equal in both KV caches.
such that the number of blocks is equal in both KV caches.
"""
"""
num_gpu_blocks
,
num_cpu_blocks
=
(
num_gpu_blocks
,
num_cpu_blocks
=
(
self
.
scorer_worker
.
profile_num_available_blocks
(
self
.
scorer_worker
.
determine_num_available_blocks
())
block_size
,
gpu_memory_utilization
,
cpu_swap_space
,
cache_dtype
))
scorer_cache_block_size_bytes
=
(
scorer_cache_block_size_bytes
=
(
self
.
scorer_worker
.
get_cache_block_size_bytes
(
self
.
scorer_worker
.
get_cache_block_size_bytes
())
block_size
,
cache_dtype
))
proposer_cache_block_size_bytes
=
(
proposer_cache_block_size_bytes
=
(
self
.
proposer_worker
.
get_cache_block_size_bytes
(
self
.
proposer_worker
.
get_cache_block_size_bytes
())
block_size
,
cache_dtype
))
new_num_gpu_blocks
=
split_num_cache_blocks_evenly
(
new_num_gpu_blocks
=
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
,
proposer_cache_block_size_bytes
,
scorer_cache_block_size_bytes
,
proposer_cache_block_size_bytes
,
num_gpu_blocks
)
num_gpu_blocks
)
return
new_num_gpu_blocks
,
num_cpu_blocks
return
new_num_gpu_blocks
,
num_cpu_blocks
def
init_cache_engine
(
self
,
cache_config
:
CacheConfig
):
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the cache engine of the scorer and proposer workers.
"""Initialize the cache engine of the scorer and proposer workers.
"""
"""
self
.
scorer_worker
.
init_cache_engine
(
cache_config
)
self
.
scorer_worker
.
initialize_cache
(
num_gpu_blocks
=
num_gpu_blocks
,
self
.
proposer_worker
.
init_cache_engine
(
cache_config
)
num_cpu_blocks
=
num_cpu_blocks
)
self
.
proposer_worker
.
initialize_cache
(
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
@@ -135,7 +174,7 @@ class SpecDecodeWorker:
...
@@ -135,7 +174,7 @@ class SpecDecodeWorker:
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
num_
spec_token
s
:
int
,
num_
lookahead_slot
s
:
int
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""Perform speculative decoding on the input batch.
"""
"""
...
@@ -144,9 +183,11 @@ class SpecDecodeWorker:
...
@@ -144,9 +183,11 @@ class SpecDecodeWorker:
"speculative decoding "
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
"requires non-None seq_group_metadata_list"
)
logger
.
info
(
f
"spec_decode_worker.execute_model
{
num_lookahead_slots
=
}
"
)
# If no spec tokens, call the proposer and scorer workers normally.
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# Used for prefill.
if
num_
spec_token
s
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
if
num_
lookahead_slot
s
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
return
self
.
_run_no_spec
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
@@ -159,7 +200,7 @@ class SpecDecodeWorker:
...
@@ -159,7 +200,7 @@ class SpecDecodeWorker:
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
k
=
num_
spec_token
s
,
k
=
num_
lookahead_slot
s
,
)
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
...
@@ -174,20 +215,24 @@ class SpecDecodeWorker:
...
@@ -174,20 +215,24 @@ class SpecDecodeWorker:
proposer and scorer model so that the KV cache is consistent between the
proposer and scorer model so that the KV cache is consistent between the
two.
two.
"""
"""
logger
.
info
(
"run proposer worker no spec"
)
self
.
proposer_worker
.
execute_model
(
self
.
proposer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
return_python_output
=
False
)
)
logger
.
info
(
"run target worker no spec"
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
sampler_output
=
self
.
scorer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
)
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
# Clear device tensors from sampler output. This reduces communication
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
# overhead when the engine runs in a different process than the workers.
...
@@ -213,11 +258,16 @@ class SpecDecodeWorker:
...
@@ -213,11 +258,16 @@ class SpecDecodeWorker:
sequence.
sequence.
"""
"""
logger
.
info
(
"get spec proposals"
)
# Generate proposals using draft worker.
# Generate proposals using draft worker.
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
blocks_to_copy
,
k
)
logger
.
info
(
"score proposals"
)
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_in
,
...
@@ -227,9 +277,11 @@ class SpecDecodeWorker:
...
@@ -227,9 +277,11 @@ class SpecDecodeWorker:
proposals
,
proposals
,
)
)
logger
.
info
(
"verify proposals"
)
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
proposal_scores
,
proposals
,
k
)
logger
.
info
(
"create output list"
)
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
k
)
accepted_token_ids
,
k
)
...
@@ -260,15 +312,26 @@ class SpecDecodeWorker:
...
@@ -260,15 +312,26 @@ class SpecDecodeWorker:
select_proposal_len_zero
=
True
)
select_proposal_len_zero
=
True
)
original_indices
=
spec_indices
+
non_spec_indices
original_indices
=
spec_indices
+
non_spec_indices
proposal_probs
=
proposal_scores
.
probs
[
spec_indices
,
:
-
1
]
# Get probabilities of target model, excluding bonus token.
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
,
:
-
1
]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
# Get bonus tokens from target model.
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
[
spec_indices
]
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
accepted_token_ids
=
self
.
rejection_sampler
(
accepted_token_ids
=
self
.
rejection_sampler
(
proposal
_probs
,
target_probs
=
proposal_verifier
_probs
,
bonus_token_ids
,
bonus_token_ids
=
bonus_token_ids
,
proposals
.
proposal_probs
,
draft_probs
=
proposal_probs
,
proposals
.
proposal_token_ids
,
draft_token_ids
=
proposal_token_ids
,
)
)
# Append output tokens from non-speculative sequences to
# Append output tokens from non-speculative sequences to
...
@@ -315,7 +378,7 @@ class SpecDecodeWorker:
...
@@ -315,7 +378,7 @@ class SpecDecodeWorker:
parent_seq_id
=
seq_id
,
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
output_token
=
token_id
,
# TODO Add verifier logprobs.
# TODO Add verifier logprobs.
logprobs
=
{
token_id
:
0.0
},
logprobs
=
{
token_id
:
Logprob
(
0.0
)
},
)
)
],
],
prompt_logprobs
=
None
,
prompt_logprobs
=
None
,
...
@@ -351,6 +414,16 @@ class SpecDecodeWorker:
...
@@ -351,6 +414,16 @@ class SpecDecodeWorker:
def
device
(
self
):
def
device
(
self
):
return
self
.
scorer_worker
.
device
return
self
.
scorer_worker
.
device
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes.
This function is only used to compose workers within a SpecDecodeWorker.
We leave composing a SpecDecodeWorker within a SpecDecodeWorker
undefined for now, although it could be implemented in the future.
See https://arxiv.org/abs/2308.04623.
"""
raise
NotImplementedError
def
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
:
int
,
def
split_num_cache_blocks_evenly
(
scorer_cache_block_size_bytes
:
int
,
proposer_cache_block_size_bytes
:
int
,
proposer_cache_block_size_bytes
:
int
,
...
...
vllm/spec_decode/util.py
View file @
99b471c2
...
@@ -82,6 +82,32 @@ def sampler_output_to_torch(
...
@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return
sampled_token_ids
,
sampled_token_probs
return
sampled_token_ids
,
sampled_token_probs
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
vocab_size
:
int
,
device
:
str
)
->
None
:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values
=
[
sampler_output
.
sampled_token_probs
,
sampler_output
.
sampled_token_ids
]
assert
all
(
v
is
None
for
v
in
values
)
or
not
any
(
v
is
None
for
v
in
values
)
if
not
any
(
v
is
None
for
v
in
values
):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output
.
sampled_token_probs
=
torch
.
nn
.
functional
.
softmax
(
torch
.
rand
(
batch_size
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
device
),
dim
=-
1
)
sampler_output
.
sampled_token_ids
=
torch
.
randint
(
low
=
10
,
high
=
100
,
size
=
(
batch_size
,
),
dtype
=
torch
.
long
,
device
=
device
)
@
contextmanager
@
contextmanager
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
"""
"""
...
...
vllm/test_utils.py
View file @
99b471c2
import
ray
import
ray
from
vllm.config
import
ParallelConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.worker
import
init_distributed_environment
def
init_test_distributed_environment
(
def
init_test_distributed_environment
(
...
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
...
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
distributed_init_port
:
str
,
distributed_init_port
:
str
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
,
tensor_parallel_size
,
worker_use_ray
=
True
)
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
init_distributed_environment
(
init_distributed_environment
(
parallel_config
,
world_size
=
pipeline_parallel_size
*
tensor_parallel_size
,
rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
local_rank
=
local_rank
)
local_rank
=
local_rank
)
ensure_model_parallel_initialized
(
tensor_parallel_size
,
pipeline_parallel_size
)
def
multi_process_tensor_parallel
(
def
multi_process_tensor_parallel
(
...
...
vllm/transformers_utils/config.py
View file @
99b471c2
from
typing
import
Optional
from
typing
import
Dict
,
Optional
from
transformers
import
AutoConfig
,
PretrainedConfig
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.transformers_utils.configs
import
*
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MPTConfig
,
RWConfig
)
_CONFIG_REGISTRY
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
PretrainedConfig
]
=
{
"chatglm"
:
ChatGLMConfig
,
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
"mpt"
:
MPTConfig
,
...
...
vllm/transformers_utils/configs/dbrx.py
View file @
99b471c2
...
@@ -12,7 +12,7 @@ from transformers.utils import logging
...
@@ -12,7 +12,7 @@ from transformers.utils import logging
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP
=
{}
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP
=
{}
# type: ignore
class
DbrxAttentionConfig
(
PretrainedConfig
):
class
DbrxAttentionConfig
(
PretrainedConfig
):
...
...
Prev
1
…
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment