Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
878 additions
and
395 deletions
+878
-395
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+118
-119
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+6
-14
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+83
-5
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+31
-41
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+22
-66
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+7
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+21
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+17
-0
vllm/sampling_params.py
vllm/sampling_params.py
+0
-13
vllm/sequence.py
vllm/sequence.py
+7
-5
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+12
-1
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+236
-80
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+3
-0
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+4
-0
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+13
-4
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+2
-2
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+1
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+223
-39
vllm/spec_decode/target_model_runner.py
vllm/spec_decode/target_model_runner.py
+69
-0
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+3
-3
No files found.
vllm/model_executor/models/phi3v.py
View file @
500b93c8
...
@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
...
@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
input_processor_for_clip
)
input_processor_for_clip
)
from
.interfaces
import
SupportsVision
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
...
@@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
class
Phi3ImageEmbeddingBase
(
nn
.
Module
):
class
Phi3ImageEmbeddingBase
(
nn
.
Module
):
def
__init__
(
self
,
wte
=
None
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
wte
=
wte
self
.
layer_idx
:
int
self
.
layer_idx
:
int
self
.
type_feature
:
str
self
.
type_feature
:
str
self
.
img_processor
:
CLIPVisionModel
self
.
img_processor
:
CLIPVisionModel
...
@@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module):
...
@@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module):
class
Phi3HDImageEmbedding
(
Phi3ImageEmbeddingBase
):
class
Phi3HDImageEmbedding
(
Phi3ImageEmbeddingBase
):
"""Phi3 Image embedding with HD transform."""
"""Phi3 Image embedding with HD transform."""
def
__init__
(
self
,
config
:
PretrainedConfig
,
wte
=
None
)
->
None
:
def
__init__
(
self
,
config
:
PretrainedConfig
)
->
None
:
super
().
__init__
(
wte
)
super
().
__init__
()
self
.
image_token_id
=
_IMAGE_TOKEN_ID
# n_embed or hidden_size
# n_embed or hidden_size
hidden_size
=
config
.
n_embd
if
hasattr
(
hidden_size
=
config
.
n_embd
if
hasattr
(
config
,
'n_embd'
)
else
config
.
hidden_size
config
,
'n_embd'
)
else
config
.
hidden_size
...
@@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
...
@@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
nn
.
Linear
(
dim_projection
,
dim_projection
)])
nn
.
Linear
(
dim_projection
,
dim_projection
)])
self
.
img_projection
=
nn
.
Sequential
(
*
layers
)
self
.
img_projection
=
nn
.
Sequential
(
*
layers
)
self
.
vocab_size
=
config
.
vocab_size
self
.
type_feature
=
config
.
img_processor
.
get
(
'type_feature'
,
'patch'
)
self
.
type_feature
=
config
.
img_processor
.
get
(
'type_feature'
,
'patch'
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
pixel_values
:
torch
.
FloatTensor
,
image_sizes
:
torch
.
Tensor
)
->
torch
.
FloatTensor
:
image_sizes
:
torch
.
Tensor
)
->
torch
.
FloatTensor
:
"""process and merge text embeddings with image embeddings."""
"""
process image and return vision embeddings.
# (batch_size, max_num_crops, 3, height, width)
img_embeds
=
pixel_values
pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
# (batch_size, 2)
"""
img_sizes
=
image_sizes
num_images
,
num_crops
,
c
,
h
,
w
=
pixel_values
.
shape
pixel_values
=
pixel_values
.
flatten
(
0
,
1
)
input_shape
=
input_ids
.
size
()
img_features
=
self
.
get_img_features
(
pixel_values
)
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
img_features
=
img_features
.
reshape
(
num_images
,
num_crops
,
-
1
,
self
.
image_dim_out
)
positions
=
torch
.
nonzero
(
input_ids
==
self
.
image_token_id
)
image_features_proj
=
self
.
hd_feature_transform
(
img_features
,
image_sizes
)
select
=
False
return
image_features_proj
target_dtype
=
self
.
img_projection
[
0
].
bias
.
dtype
def
hd_feature_transform
(
self
,
image_features
,
image_sizes
):
"""
if
len
(
positions
.
tolist
())
>
0
:
image_features: (num_images, num_crops+1, 24*24, 1024)
# if self.use_hd_transform and img_sizes:
"""
# img_embeds: (num_images, max_num_crops, 3, H, W)
assert
(
# img_sizes: (num_images, 2).view(1, -1)
self
.
hd_transform_order
==
'sub_glb'
),
f
'hd_transform_order `
{
self
.
hd_transform_order
}
` not implemented'
bs
=
img_embeds
.
shape
[
0
]
if
isinstance
(
self
.
img_projection
,
nn
.
Sequential
):
# Nx(HW)xC
target_device
=
self
.
img_projection
[
0
].
bias
.
device
img_features
=
self
.
get_img_features
(
img_embeds
.
flatten
(
0
,
1
))
target_dtype
=
self
.
img_projection
[
0
].
bias
.
dtype
base_feat_height
=
base_feat_width
=
int
(
else
:
# It's a single nn.Linear layer
img_features
.
shape
[
1
]
**
0.5
)
target_device
=
self
.
img_projection
.
bias
.
device
target_dtype
=
self
.
img_projection
.
bias
.
dtype
# bs x max_num_crops x (24x24) x C
img_features
=
img_features
.
view
(
global_image_features
=
image_features
[:,
bs
,
-
1
,
base_feat_height
*
base_feat_width
,
self
.
image_dim_out
)
0
]
# (num_images, 24*24, 1024)
C
=
self
.
image_dim_out
# global feature can be viewed as a special HD case with num_crops 1x1
H
=
base_feat_height
global_image_features_hd
=
self
.
reshape_hd_patches_2x2merge
(
global_image_features
,
1
,
1
)
output_imgs
=
[]
global_image_features_hd_newline
=
self
.
add_image_newline
(
output_len
=
[]
global_image_features_hd
)
for
_bs
in
range
(
bs
):
all_image_embeddings
=
[]
h
,
w
=
img_sizes
[
_bs
]
# need a for loop to process each image because of different image sizes
h
=
h
//
336
# (patch arrangement is different for each image)
w
=
w
//
336
for
i
,
img_size
in
enumerate
(
image_sizes
):
B_
=
h
*
w
h
,
w
=
img_size
h_crop
=
h
//
336
# 1 x (24x24) x 1024
w_crop
=
w
//
336
global_img_feature
=
img_features
[
_bs
,
:
1
]
num_crops
=
h_crop
*
w_crop
# 1 x 12 x 12 x 4096
# NOTE: real num_crops is padded
glb_img
=
global_img_feature
\
# (num_crops, 24*24, 1024)
.
reshape
(
1
,
H
//
2
,
2
,
H
//
2
,
2
,
C
)
\
sub_image_features
=
image_features
[
i
,
1
:
1
+
num_crops
]
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
\
sub_image_features_hd
=
self
.
reshape_hd_patches_2x2merge
(
.
reshape
(
1
,
H
//
2
,
H
//
2
,
4
*
C
)
sub_image_features
,
h_crop
,
w_crop
)
temp_glb_GN
=
self
.
sub_GN
.
repeat
(
1
,
H
//
2
,
1
,
1
)
sub_image_features_hd_newline
=
self
.
add_image_newline
(
sub_image_features_hd
)
# 1 x 156 x 4096
glb_img
=
torch
.
cat
([
glb_img
,
temp_glb_GN
],
# [sub features, separator, global features]
dim
=
2
).
reshape
(
1
,
-
1
,
4
*
C
)
all_image_embeddings
.
append
(
torch
.
cat
([
# (max_num_crops-1) x (12x12) x C
sub_image_features_hd_newline
.
squeeze
(
sub_img
=
img_features
[
_bs
,
1
:]
0
),
# (h_crop*12*(w_crop*12+1), 4096)
# 16x574x1024
self
.
glb_GN
.
squeeze
(
0
),
# get rid of padding sub_img
global_image_features_hd_newline
[
i
],
sub_img
=
sub_img
[:
B_
]
]))
sub_img
=
sub_img
.
reshape
(
B_
,
H
//
2
,
2
,
H
//
2
,
2
,
C
)
\
image_features_proj
=
self
.
img_projection
(
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
reshape
(
B_
,
-
1
,
4
*
C
)
torch
.
stack
(
all_image_embeddings
).
to
(
target_device
,
target_dtype
)
sub_img
=
sub_img
.
reshape
(
1
,
h
,
w
,
12
,
12
,
-
1
)
\
)
# (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
\
.
reshape
(
1
,
h
*
12
,
w
*
12
,
4
*
C
)
return
image_features_proj
temp_sub_GN
=
self
.
sub_GN
.
repeat
(
1
,
h
*
12
,
1
,
1
)
sub_img
=
torch
.
cat
([
sub_img
,
temp_sub_GN
],
def
reshape_hd_patches_2x2merge
(
self
,
image_features
,
h_crop
,
w_crop
):
dim
=
2
).
reshape
(
1
,
-
1
,
4
*
C
)
"""
# (1, num_img_tokens, 1024*4)
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096)
# glb + sub
where h_crop*w_crop == num_crops
if
self
.
hd_transform_order
==
'glb_sub'
:
"""
output_imgs
.
append
(
N
,
L
,
C
=
image_features
.
shape
torch
.
cat
([
glb_img
,
self
.
glb_GN
,
sub_img
],
dim
=
1
))
assert
L
==
576
and
C
==
1024
and
N
%
(
h_crop
*
w_crop
)
==
0
elif
self
.
hd_transform_order
==
'sub_glb'
:
num_images
=
N
//
(
h_crop
*
w_crop
)
output_imgs
.
append
(
H
=
int
(
L
**
0.5
)
torch
.
cat
([
sub_img
,
self
.
glb_GN
,
glb_img
],
dim
=
1
))
image_features_hd
=
(
image_features
.
reshape
(
N
,
H
,
H
,
C
)
# N, 24, 24, 1024
temp_len
=
int
((
h
*
w
+
1
)
*
144
+
1
+
(
h
+
1
)
*
12
)
.
reshape
(
N
,
H
//
2
,
2
,
H
//
2
,
2
,
C
)
# N, 12, 2, 12, 2, 1024
output_len
.
append
(
temp_len
)
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
# N, 12, 12, 2, 2, 1024
.
reshape
(
N
,
-
1
,
4
*
C
)
# N, 144, 4096
num_img_tokens
=
output_len
.
reshape
(
num_images
,
h_crop
,
w_crop
,
H
//
2
,
H
//
2
,
img_set_tensor
=
[]
-
1
)
# n_img, h_crop, w_crop, 12, 12, 4096
for
_output_img
in
output_imgs
:
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
# n_img, h_crop, 12, w_crop, 12, 4096
img_feature_proj
=
self
.
img_projection
(
.
reshape
(
num_images
,
h_crop
*
H
//
2
,
w_crop
*
H
//
2
,
_output_img
.
to
(
target_dtype
))
4
*
C
)
# n_img, h_crop*12, w_crop*12, 4096
img_set_tensor
.
append
(
img_feature_proj
)
)
select
=
True
return
image_features_hd
input_ids
.
clamp_min_
(
0
).
clamp_max_
(
self
.
vocab_size
)
def
add_image_newline
(
self
,
image_features_hd
):
"""
hidden_states
=
self
.
wte
(
input_ids
)
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
if
select
:
"""
idx
=
0
num_images
,
h
,
w
,
hid_dim
=
image_features_hd
.
shape
for
i
,
cnt
in
enumerate
(
num_img_tokens
):
# add the newline token to the HD image feature patches
hidden_states
[
positions
[
idx
,
0
],
newline_embeddings
=
self
.
sub_GN
.
expand
(
num_images
,
h
,
-
1
,
positions
[
idx
,
1
]:
positions
[
idx
,
1
]
+
-
1
)
# (n_img, h, 1, hid_dim)
cnt
]
=
(
img_set_tensor
[
i
].
to
(
image_features_hd_newline
=
torch
.
cat
(
hidden_states
.
dtype
))
[
image_features_hd
,
newline_embeddings
],
idx
+=
cnt
dim
=
2
).
reshape
(
num_images
,
-
1
,
hid_dim
)
return
image_features_hd_newline
return
hidden_states
.
squeeze
(
0
)
class
Phi3VImagePixelInputs
(
TypedDict
):
class
Phi3VImagePixelInputs
(
TypedDict
):
...
@@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
...
@@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
image_token_id
=
_IMAGE_TOKEN_ID
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
)
# TODO: Optionally initializes this for supporting embeddings.
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
config
)
config
,
self
.
model
.
embed_tokens
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
@@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
...
@@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
if
image_input
is
not
None
:
inputs_embeds
=
self
.
vision_embed_tokens
(
vision_embeddings
=
self
.
vision_embed_tokens
(
input_ids
,
image_input
[
"data"
],
image_input
[
"image_sizes"
])
image_input
[
"data"
],
image_input
[
"image_sizes"
])
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
input_ids
=
None
input_ids
=
None
else
:
else
:
inputs_embeds
=
None
inputs_embeds
=
None
...
...
vllm/model_executor/models/qwen2.py
View file @
500b93c8
...
@@ -45,10 +45,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -45,10 +45,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
...
@@ -392,18 +392,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -392,18 +392,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
remapped_kv_scale_name
=
name
.
replace
(
if
name
is
None
:
".kv_scale"
,
".attn.kv_scale"
)
continue
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
f
"Found kv scale in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). kv-scale is "
"not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/utils.py
View file @
500b93c8
from
typing
import
Callable
,
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Protocol
,
Tuple
import
torch
import
torch
from
torch.func
import
functional_call
from
vllm.multimodal
import
BatchedTensors
from
vllm.multimodal
import
BatchedTensors
from
vllm.utils
import
is_pin_memory_available
def
merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
def
merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
...
@@ -43,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
...
@@ -43,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
return
inputs_embeds
return
inputs_embeds
class
LayerFn
(
Protocol
):
def
__call__
(
self
,
prefix
=
""
,
)
->
torch
.
nn
.
Module
:
...
class
PPMissingLayer
(
torch
.
nn
.
Identity
):
class
PPMissingLayer
(
torch
.
nn
.
Identity
):
"""
"""
A placeholder layer for missing layers in a pipeline parallel model.
A placeholder layer for missing layers in a pipeline parallel model.
...
@@ -52,8 +63,74 @@ class PPMissingLayer(torch.nn.Identity):
...
@@ -52,8 +63,74 @@ class PPMissingLayer(torch.nn.Identity):
super
().
__init__
()
super
().
__init__
()
_CPU_OFFLOAD_BYTES
=
0
_CPU_OFFLOAD_MAX_BYTES
=
0
def
set_cpu_offload_max_bytes
(
max_bytes
:
int
)
->
None
:
global
_CPU_OFFLOAD_MAX_BYTES
,
_CPU_OFFLOAD_BYTES
_CPU_OFFLOAD_BYTES
=
0
_CPU_OFFLOAD_MAX_BYTES
=
max_bytes
def
maybe_offload_to_cpu
(
module
:
torch
.
nn
.
Module
)
->
torch
.
nn
.
Module
:
device
=
next
(
module
.
parameters
()).
device
if
device
==
torch
.
device
(
"cpu"
):
return
module
global
_CPU_OFFLOAD_MAX_BYTES
,
_CPU_OFFLOAD_BYTES
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
return
module
pin_memory
=
is_pin_memory_available
()
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
for
p
in
module
.
parameters
():
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
# we use per-parameter offloading
# one module might have some parameters offloaded and some not
break
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty
(
size
=
p
.
data
.
size
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
'cpu'
,
pin_memory
=
pin_memory
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
_CPU_OFFLOAD_BYTES
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
state_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
module
.
state_dict
()
original_forward
=
module
.
forward
def
forward
(
*
args
,
**
kwargs
):
module
.
forward
=
original_forward
device_state
=
{
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k
:
v
.
to
(
device
,
non_blocking
=
True
)
for
k
,
v
in
state_dict
.
items
()
}
output
=
functional_call
(
module
,
device_state
,
args
=
args
,
kwargs
=
kwargs
)
module
.
forward
=
forward
return
output
module
.
forward
=
forward
return
module
def
make_layers
(
def
make_layers
(
num_hidden_layers
:
int
,
layer_fn
:
Callable
[[],
torch
.
nn
.
Module
]
num_hidden_layers
:
int
,
layer_fn
:
LayerFn
,
prefix
:
str
,
)
->
Tuple
[
int
,
int
,
torch
.
nn
.
ModuleList
]:
)
->
Tuple
[
int
,
int
,
torch
.
nn
.
ModuleList
]:
"""Make a list of layers with the given layer function, taking
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
pipeline parallelism into account.
...
@@ -64,9 +141,10 @@ def make_layers(
...
@@ -64,9 +141,10 @@ def make_layers(
get_pp_group
().
rank_in_group
,
get_pp_group
().
rank_in_group
,
get_pp_group
().
world_size
)
get_pp_group
().
world_size
)
modules
=
torch
.
nn
.
ModuleList
(
modules
=
torch
.
nn
.
ModuleList
(
[
PPMissingLayer
()
for
_
in
range
(
start_layer
)]
+
[
PPMissingLayer
()
for
_
in
range
(
start_layer
)]
+
[
[
layer_fn
()
for
_
in
range
(
start_layer
,
end_layer
)]
+
maybe_offload_to_cpu
(
layer_fn
(
prefix
=
f
"
{
prefix
}
.
{
idx
}
"
))
[
PPMissingLayer
()
for
_
in
range
(
end_layer
,
num_hidden_layers
)])
for
idx
in
range
(
start_layer
,
end_layer
)
]
+
[
PPMissingLayer
()
for
_
in
range
(
end_layer
,
num_hidden_layers
)])
return
start_layer
,
end_layer
,
modules
return
start_layer
,
end_layer
,
modules
...
...
vllm/model_executor/sampling_metadata.py
View file @
500b93c8
...
@@ -8,7 +8,7 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
...
@@ -8,7 +8,7 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
maybe_expand_dim
)
make_tensor_with_pad
,
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
_SEED_0_REPLACEMENT
=
3403598558
...
@@ -86,6 +86,12 @@ class SamplingMetadata:
...
@@ -86,6 +86,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -94,11 +100,15 @@ class SamplingMetadata:
...
@@ -94,11 +100,15 @@ class SamplingMetadata:
selected_token_indices
:
torch
.
Tensor
,
selected_token_indices
:
torch
.
Tensor
,
categorized_sample_indices
:
Dict
[
SamplingType
,
torch
.
Tensor
],
categorized_sample_indices
:
Dict
[
SamplingType
,
torch
.
Tensor
],
num_prompts
:
int
,
num_prompts
:
int
,
skip_sampler_cpu_output
:
bool
=
False
,
reuse_sampling_tensors
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
seq_groups
=
seq_groups
self
.
selected_token_indices
=
selected_token_indices
self
.
selected_token_indices
=
selected_token_indices
self
.
categorized_sample_indices
=
categorized_sample_indices
self
.
categorized_sample_indices
=
categorized_sample_indices
self
.
num_prompts
=
num_prompts
self
.
num_prompts
=
num_prompts
self
.
skip_sampler_cpu_output
=
skip_sampler_cpu_output
self
.
reuse_sampling_tensors
=
reuse_sampling_tensors
@
staticmethod
@
staticmethod
def
prepare
(
def
prepare
(
...
@@ -455,18 +465,24 @@ class SamplingTensors:
...
@@ -455,18 +465,24 @@ class SamplingTensors:
do_penalties
=
prompt_tokens
or
output_tokens
do_penalties
=
prompt_tokens
or
output_tokens
if
do_penalties
:
if
do_penalties
:
prompt_max_len
=
max
([
len
(
tokens
)
for
tokens
in
prompt_tokens
],
prompt_t
=
make_tensor_with_pad
(
default
=
0
)
prompt_tokens
,
prompt_padded_tokens
=
[
vocab_size
,
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
device
=
"cpu"
,
for
tokens
in
prompt_tokens
dtype
=
torch
.
int64
,
]
pin_memory
=
pin_memory
,
output_max_len
=
max
([
len
(
tokens
)
for
tokens
in
output_tokens
],
)
default
=
0
)
output_t
=
make_tensor_with_pad
(
output_padded_tokens
=
[
output_tokens
,
tokens
+
[
vocab_size
]
*
(
output_max_len
-
len
(
tokens
))
vocab_size
,
for
tokens
in
output_tokens
device
=
"cpu"
,
]
dtype
=
torch
.
int64
,
pin_memory
=
pin_memory
,
)
else
:
empty_tensor
=
torch
.
empty
(
0
,
device
=
device
,
dtype
=
torch
.
long
)
prompt_t
=
empty_tensor
output_t
=
empty_tensor
temperatures_t
=
torch
.
tensor
(
temperatures_t
=
torch
.
tensor
(
temperatures
,
temperatures
,
...
@@ -516,22 +532,6 @@ class SamplingTensors:
...
@@ -516,22 +532,6 @@ class SamplingTensors:
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
if
do_penalties
:
prompt_tensor
=
torch
.
tensor
(
prompt_padded_tokens
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
output_tensor
=
torch
.
tensor
(
output_padded_tokens
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
else
:
prompt_tensor
=
None
output_tensor
=
None
# need to transpose and make contiguous to
# need to transpose and make contiguous to
# copy the tensor correctly.
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
# [batch_size, n_seeds] -> [n_seeds, batch_size]
...
@@ -554,16 +554,6 @@ class SamplingTensors:
...
@@ -554,16 +554,6 @@ class SamplingTensors:
extra_seeds_gpu
=
None
extra_seeds_gpu
=
None
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
if
do_penalties
:
prompt_tokens_gpu
=
prompt_tensor
.
to
(
device
=
device
,
non_blocking
=
True
)
output_tokens_gpu
=
output_tensor
.
to
(
device
=
device
,
non_blocking
=
True
)
else
:
empty_tensor
=
torch
.
empty
(
0
,
device
=
device
,
dtype
=
torch
.
long
)
prompt_tokens_gpu
=
empty_tensor
output_tokens_gpu
=
empty_tensor
return
cls
(
return
cls
(
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
...
@@ -575,8 +565,8 @@ class SamplingTensors:
...
@@ -575,8 +565,8 @@ class SamplingTensors:
non_blocking
=
True
),
non_blocking
=
True
),
repetition_penalties
=
repetition_penalties_t
.
to
(
device
=
device
,
repetition_penalties
=
repetition_penalties_t
.
to
(
device
=
device
,
non_blocking
=
True
),
non_blocking
=
True
),
prompt_tokens
=
prompt_t
okens_gpu
,
prompt_tokens
=
prompt_t
.
to
(
device
=
device
,
non_blocking
=
True
)
,
output_tokens
=
output_t
okens_gpu
,
output_tokens
=
output_t
.
to
(
device
=
device
,
non_blocking
=
True
)
,
sampling_seeds
=
sampling_seeds_gpu
,
sampling_seeds
=
sampling_seeds_gpu
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
non_blocking
=
True
),
non_blocking
=
True
),
...
...
vllm/multimodal/utils.py
View file @
500b93c8
import
base64
import
base64
from
io
import
BytesIO
from
io
import
BytesIO
from
typing
import
Optional
,
Union
from
typing
import
Union
from
urllib.parse
import
urlparse
import
aiohttp
import
requests
from
PIL
import
Image
from
PIL
import
Image
from
vllm.connections
import
global_http_connection
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
from
vllm.multimodal.base
import
MultiModalDataDict
from
vllm.multimodal.base
import
MultiModalDataDict
from
vllm.version
import
__version__
as
VLLM_VERSION
def
_validate_remote_url
(
url
:
str
,
*
,
name
:
str
):
parsed_url
=
urlparse
(
url
)
if
parsed_url
.
scheme
not
in
[
"http"
,
"https"
]:
raise
ValueError
(
f
"Invalid '
{
name
}
': A valid '
{
name
}
' "
"must have scheme 'http' or 'https'."
)
def
_get_request_headers
():
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
}
def
_load_image_from_bytes
(
b
:
bytes
):
def
_load_image_from_bytes
(
b
:
bytes
):
...
@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
...
@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
By default, the image is converted into RGB format.
By default, the image is converted into RGB format.
"""
"""
if
image_url
.
startswith
(
'http'
):
if
image_url
.
startswith
(
'http'
):
_validate_remote_url
(
image_url
,
name
=
"image_url"
)
image_raw
=
global_http_connection
.
get_bytes
(
image_url
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
headers
=
_get_request_headers
()
with
requests
.
get
(
url
=
image_url
,
headers
=
headers
)
as
response
:
response
.
raise_for_status
()
image_raw
=
response
.
content
image
=
_load_image_from_bytes
(
image_raw
)
image
=
_load_image_from_bytes
(
image_raw
)
elif
image_url
.
startswith
(
'data:image'
):
elif
image_url
.
startswith
(
'data:image'
):
...
@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
...
@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
return
image
.
convert
(
image_mode
)
return
image
.
convert
(
image_mode
)
class
ImageFetchAiohttp
:
async
def
async_fetch_image
(
image_url
:
str
,
aiohttp_client
:
Optional
[
aiohttp
.
ClientSession
]
=
None
*
,
image_mode
:
str
=
"RGB"
)
->
Image
.
Image
:
@
classmethod
"""
def
get_aiohttp_client
(
cls
)
->
aiohttp
.
ClientSession
:
Asynchronously load a PIL image from a HTTP or base64 data URL.
if
cls
.
aiohttp_client
is
None
:
timeout
=
aiohttp
.
ClientTimeout
(
total
=
VLLM_IMAGE_FETCH_TIMEOUT
)
connector
=
aiohttp
.
TCPConnector
()
cls
.
aiohttp_client
=
aiohttp
.
ClientSession
(
timeout
=
timeout
,
connector
=
connector
)
return
cls
.
aiohttp_client
@
classmethod
async
def
fetch_image
(
cls
,
image_url
:
str
,
*
,
image_mode
:
str
=
"RGB"
,
)
->
Image
.
Image
:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if
image_url
.
startswith
(
'http'
):
_validate_remote_url
(
image_url
,
name
=
"image_url"
)
client
=
cls
.
get_aiohttp_client
()
headers
=
_get_request_headers
()
async
with
client
.
get
(
url
=
image_url
,
headers
=
headers
)
as
response
:
By default, the image is converted into RGB format.
response
.
raise_for_status
()
"""
image_raw
=
await
response
.
read
()
if
image_url
.
startswith
(
'http'
):
image
=
_load_image_from_bytes
(
image_raw
)
image_raw
=
await
global_http_connection
.
async_get_bytes
(
image_url
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
image
=
_load_image_from_bytes
(
image_raw
)
elif
image_url
.
startswith
(
'data:image'
):
elif
image_url
.
startswith
(
'data:image'
):
image
=
_load_image_from_data_url
(
image_url
)
image
=
_load_image_from_data_url
(
image_url
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Invalid 'image_url': A valid 'image_url' must start "
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'."
)
"with either 'data:image' or 'http'."
)
return
image
.
convert
(
image_mode
)
return
image
.
convert
(
image_mode
)
async
def
async_get_and_parse_image
(
image_url
:
str
)
->
MultiModalDataDict
:
async
def
async_get_and_parse_image
(
image_url
:
str
)
->
MultiModalDataDict
:
image
=
await
ImageFetchAiohttp
.
fetch_image
(
image_url
)
image
=
await
async_
fetch_image
(
image_url
)
return
{
"image"
:
image
}
return
{
"image"
:
image
}
...
...
vllm/platforms/__init__.py
View file @
500b93c8
...
@@ -2,7 +2,9 @@ from typing import Optional
...
@@ -2,7 +2,9 @@ from typing import Optional
import
torch
import
torch
from
.interface
import
Platform
,
PlatformEnum
from
vllm.utils
import
is_tpu
from
.interface
import
Platform
,
PlatformEnum
,
UnspecifiedPlatform
current_platform
:
Optional
[
Platform
]
current_platform
:
Optional
[
Platform
]
...
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
...
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
elif
torch
.
version
.
hip
is
not
None
:
elif
torch
.
version
.
hip
is
not
None
:
from
.rocm
import
RocmPlatform
from
.rocm
import
RocmPlatform
current_platform
=
RocmPlatform
()
current_platform
=
RocmPlatform
()
elif
is_tpu
():
from
.tpu
import
TpuPlatform
current_platform
=
TpuPlatform
()
else
:
else
:
current_platform
=
None
current_platform
=
UnspecifiedPlatform
()
__all__
=
[
'Platform'
,
'PlatformEnum'
,
'current_platform'
]
__all__
=
[
'Platform'
,
'PlatformEnum'
,
'current_platform'
]
vllm/platforms/interface.py
View file @
500b93c8
import
enum
import
enum
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
class
PlatformEnum
(
enum
.
Enum
):
class
PlatformEnum
(
enum
.
Enum
):
CUDA
=
enum
.
auto
()
CUDA
=
enum
.
auto
()
ROCM
=
enum
.
auto
()
ROCM
=
enum
.
auto
()
TPU
=
enum
.
auto
()
UNSPECIFIED
=
enum
.
auto
()
class
Platform
:
class
Platform
:
...
@@ -16,6 +20,23 @@ class Platform:
...
@@ -16,6 +20,23 @@ class Platform:
def
is_rocm
(
self
)
->
bool
:
def
is_rocm
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
ROCM
return
self
.
_enum
==
PlatformEnum
.
ROCM
def is_tpu(self) -> bool:
    """Return True when this platform's `_enum` equals `PlatformEnum.TPU`."""
    return self._enum == PlatformEnum.TPU
@
staticmethod
@
staticmethod
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
def
get_device_capability
(
device_id
:
int
=
0
)
->
Tuple
[
int
,
int
]:
raise
NotImplementedError
raise
NotImplementedError
@staticmethod
def inference_mode():
    """A device-specific wrapper of `torch.inference_mode`.

    This wrapper is recommended because some hardware backends such as TPU
    do not support `torch.inference_mode`. In such a case, they will fall
    back to `torch.no_grad` by overriding this method.

    Returns:
        The object produced by `torch.inference_mode(mode=True)`.
    """
    return torch.inference_mode(mode=True)
class UnspecifiedPlatform(Platform):
    """`Platform` subclass tagged with `PlatformEnum.UNSPECIFIED`.

    It overrides nothing else, so every other method is inherited from
    `Platform` unchanged.
    """
    # NOTE(review): presumably the fallback used when no CUDA/ROCm/TPU
    # backend is detected -- confirm against vllm/platforms/__init__.py.
    _enum = PlatformEnum.UNSPECIFIED
vllm/platforms/tpu.py
0 → 100644
View file @
500b93c8
from
typing
import
Tuple
import
torch
from
.interface
import
Platform
,
PlatformEnum
class TpuPlatform(Platform):
    """`Platform` subclass tagged with `PlatformEnum.TPU`."""

    _enum = PlatformEnum.TPU

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Always raise; no device capability is reported for TPU.

        Args:
            device_id: Accepted for interface compatibility; not used.

        Raises:
            RuntimeError: Unconditionally.
        """
        raise RuntimeError("TPU does not have device capability.")

    @staticmethod
    def inference_mode():
        """Return `torch.no_grad()` in place of `torch.inference_mode`."""
        return torch.no_grad()
vllm/sampling_params.py
View file @
500b93c8
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
from
pydantic
import
Field
from
pydantic
import
Field
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -189,18 +188,6 @@ class SamplingParams:
...
@@ -189,18 +188,6 @@ class SamplingParams:
self
.
_verify_args
()
self
.
_verify_args
()
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
# Lazy import to avoid circular imports.
from
vllm.usage.usage_lib
import
set_runtime_usage_data
set_runtime_usage_data
(
"use_beam_search"
,
True
)
if
not
envs
.
VLLM_NO_DEPRECATION_WARNING
:
logger
.
warning
(
"[IMPORTANT] We plan to discontinue the support for beam "
"search in the next major release. Please refer to "
"https://github.com/vllm-project/vllm/issues/6226 for "
"more information. Set VLLM_NO_DEPRECATION_WARNING=1 to "
"suppress this warning."
)
self
.
_verify_beam_search
()
self
.
_verify_beam_search
()
else
:
else
:
self
.
_verify_non_beam_search
()
self
.
_verify_non_beam_search
()
...
...
vllm/sequence.py
View file @
500b93c8
...
@@ -5,7 +5,8 @@ import math
...
@@ -5,7 +5,8 @@ import math
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
)
import
torch
import
torch
...
@@ -438,7 +439,7 @@ class SequenceGroup:
...
@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings
:
Optional
[
List
[
float
]]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
trace_headers
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
...
@@ -457,24 +458,25 @@ class SequenceGroup:
...
@@ -457,24 +458,25 @@ class SequenceGroup:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
encoder_seq
=
encoder_seq
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
self
.
trace_headers
=
trace_headers
self
.
_first_seq
=
next
(
iter
(
self
.
seqs_dict
.
values
()))
@
property
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
def
prompt
(
self
)
->
Optional
[
str
]:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
# We use the prompt of an arbitrary sequence.
return
next
(
iter
(
self
.
seqs_dict
.
values
()))
.
prompt
return
self
.
_first_seq
.
prompt
@
property
@
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
def
prompt_token_ids
(
self
)
->
List
[
int
]:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
# We use the prompt of an arbitrary sequence.
return
next
(
iter
(
self
.
seqs_dict
.
values
()))
.
prompt_token_ids
return
self
.
_first_seq
.
prompt_token_ids
@
property
@
property
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
# All sequences in the group should have the same multi-modal data.
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
# We use the multi-modal data of an arbitrary sequence.
return
next
(
iter
(
self
.
seqs_dict
.
values
()))
.
multi_modal_data
return
self
.
_first_seq
.
multi_modal_data
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
...
...
vllm/spec_decode/batch_expansion.py
View file @
500b93c8
...
@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
...
@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
import
torch
import
torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
SequenceGroupMetadata
,
SequenceGroupState
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for
data
in
new_seq_data_dict
.
values
():
for
data
in
new_seq_data_dict
.
values
():
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generator
=
torch
.
Generator
(
device
=
seq_group_metadata
.
state
.
generator
.
device
)
generator
.
set_state
(
seq_group_metadata
.
state
.
generator
.
get_state
())
state
=
SequenceGroupState
(
generator
=
generator
)
else
:
state
=
None
return
SequenceGroupMetadata
(
return
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
...
@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
},
},
lora_request
=
None
,
lora_request
=
None
,
token_chunk_size
=
1
,
token_chunk_size
=
1
,
state
=
state
,
)
)
def
_split_scoring_output
(
def
_split_scoring_output
(
...
...
vllm/spec_decode/draft_model_runner.py
View file @
500b93c8
...
@@ -2,17 +2,33 @@ from typing import List, Optional
...
@@ -2,17 +2,33 @@ from typing import List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
except
ModuleNotFoundError
:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
S
equenceGroupMetadata
)
S
amplerOutput
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
ModelRunner
)
ModelRunner
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input
=
False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step
=
True
class
TP1DraftModelRunner
(
ModelRunner
):
class
TP1DraftModelRunner
(
ModelRunner
):
"""Specialized model runner for speculative decoding draft model.
"""Specialized model runner for speculative decoding draft model.
...
@@ -21,18 +37,9 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -21,18 +37,9 @@ class TP1DraftModelRunner(ModelRunner):
we could get rid of most CPU-GPU synchronization and data transfer
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
overheads by keeping model input and output tensors on GPU all the time.
This runner is still under development so there's no performance gain
TODOs:
at this moment. Currently we adopt a temporary solution that caches the
1. Currently supports only flash-attn, add support for other attn_backends.
seq_group_metadata_list for multi-step execution, so that we can
2. Support TP > 1 (this requires some designs because we do not expect
leverage existing prepare_model_input to be compatible with the current
execution flow, but we plan to remove this cache and avoid calling
prepare_model_input in execute_model at all.
The detail development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
any broadcasting inside execute_model).
"""
"""
...
@@ -71,51 +78,156 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -71,51 +78,156 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states
=
return_hidden_states
,
return_hidden_states
=
return_hidden_states
,
)
)
# TODO: Remove this cache when we are able to update model_input
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
# directly in advance_step.
num_queries
):
self
.
cached_seq_group_metadata_list
:
Optional
[
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
List
[
SequenceGroupMetadata
]]
=
None
def
prepare_model_input
(
if
num_seqs
!=
num_queries
:
self
,
assert
num_seqs
>
num_queries
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
assert
attn_metadata
.
use_cuda_graph
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
assert
attn_metadata
.
num_prefills
==
0
)
->
ModelInputForGPUWithSamplingMetadata
:
assert
attn_metadata
.
num_prefill_tokens
==
0
"""A temporary solution that caches the seq_group_metadata_list
assert
attn_metadata
.
num_decode_tokens
==
num_seqs
for multi-step execution.
assert
attn_metadata
.
slot_mapping
.
shape
==
(
num_seqs
,
)
TODO: In-place update model_input and remove this function.
"""
assert
len
(
attn_metadata
.
seq_lens
)
==
num_seqs
self
.
cached_seq_group_metadata_list
=
seq_group_metadata_list
assert
attn_metadata
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
return
super
().
prepare_model_input
(
assert
attn_metadata
.
max_query_len
==
1
seq_group_metadata_list
,
assert
attn_metadata
.
max_prefill_seq_len
==
0
finished_requests_ids
=
finished_requests_ids
)
assert
attn_metadata
.
max_decode_seq_len
==
max
(
attn_metadata
.
seq_lens
)
assert
attn_metadata
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
attn_metadata
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
attn_metadata
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
def
update_model_input
(
assert
attn_metadata
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
attn_metadata
.
seq_lens
[
i
]
+=
1
attn_metadata
.
max_decode_seq_len
=
max
(
attn_metadata
.
seq_lens
)
def _update_sampling_metadata(self, sampling_metadata, num_seqs: int,
                              num_queries: int) -> None:
    """Sanity-check that `sampling_metadata` describes a decode-only batch.

    Despite the name, nothing is modified here: the body consists solely of
    assertions, so it has no effect when Python runs with assertions
    disabled (`-O`).

    Args:
        sampling_metadata: Object exposing `num_prompts`, `seq_groups` and
            `selected_token_indices`.
        num_seqs: Not used by this method.
        num_queries: Expected number of sequence groups and of selected
            token indices.

    Raises:
        AssertionError: If any of the checks below fails.
    """
    assert sampling_metadata.num_prompts == 0
    assert len(sampling_metadata.seq_groups) == num_queries
    assert sampling_metadata.selected_token_indices.shape == (num_queries, )
    # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

    # Verify that all sequences are decodes
    for i in range(num_queries):
        seq_group = sampling_metadata.seq_groups[i]

        assert seq_group.is_prompt is False  # No prompt
        assert seq_group.prompt_logprob_indices == []  # No prompt
        # Each group is expected to sample exactly its own index.
        assert seq_group.sample_indices == [i]  # Simple
        assert seq_group.seq_len is None  # Decode
        assert seq_group.query_len is None  # Decode
def _gpu_advance_step(
        self, model_input: ModelInputForGPUWithSamplingMetadata,
        last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
    """Build the model input for the next decode step from the last output.

    The attention metadata is updated via `_update_flash_attn_metadata`,
    the input tensors are advanced by `ops.advance_step`, and the sampling
    metadata is sanity-checked via `_update_sampling_metadata`. A new
    model-input object is then created that reuses the tensors and metadata
    of `model_input`.

    Args:
        model_input: Input of the step that just ran; must not be a prompt
            and must carry `FlashAttentionMetadata`.
        last_output: Sampler output of the step that just ran; its
            `sampled_token_ids` must not be None.

    Returns:
        A new model input (built with `self._model_input_cls`) for the next
        step, with `sampling_metadata.reuse_sampling_tensors` set to True.

    Raises:
        AssertionError: If any of the preconditions above does not hold.
    """
    # Currently, we expect "decode mode" only
    assert not model_input.is_prompt

    # Get num_seqs
    num_seqs = len(model_input.seq_lens)
    num_queries = len(model_input.query_lens)

    # Get output tokens GPU tensor
    sampled_token_ids = last_output.sampled_token_ids
    assert sampled_token_ids is not None

    # Update attn_metadata
    attn_metadata = model_input.attn_metadata
    assert isinstance(attn_metadata, FlashAttentionMetadata)
    self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)

    # Update GPU tensors
    ops.advance_step(num_seqs=num_seqs,
                     num_queries=num_queries,
                     block_size=self.block_size,
                     input_tokens=model_input.input_tokens,
                     sampled_token_ids=sampled_token_ids,
                     input_positions=model_input.input_positions,
                     seq_lens=attn_metadata.seq_lens_tensor,
                     slot_mapping=attn_metadata.slot_mapping,
                     block_tables=attn_metadata.block_tables)

    # Update sampling_metadata
    sampling_metadata = model_input.sampling_metadata
    self._update_sampling_metadata(sampling_metadata, num_seqs, num_queries)

    # Create new input
    new_model_input = self._model_input_cls(
        input_tokens=model_input.input_tokens,
        input_positions=model_input.input_positions,
        attn_metadata=attn_metadata,
        seq_lens=attn_metadata.seq_lens,
        query_lens=model_input.query_lens,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        sampling_metadata=model_input.sampling_metadata,
        is_prompt=False,
    )

    # Ensure we skip CPU samples
    assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
    # We can reuse sampling tensors since every decode iteration is the same
    new_model_input.sampling_metadata.reuse_sampling_tensors = True

    if debug_advance_input:
        logger.debug("NEW INPUT: ")
        logger.debug(" input_tokens = %s", new_model_input.input_tokens)
        logger.debug(" input_positions = %s",
                     new_model_input.input_positions)
        # seq_lens and query_lens are sequences (len() is taken on them
        # above), so they must be formatted with %s; %d would make the
        # logging call fail with a formatting error.
        logger.debug(" seq_lens = %s", new_model_input.seq_lens)
        logger.debug(" query_lens = %s", new_model_input.query_lens)
        logger.debug(" attn_metadata:")
        logger.debug(" seq_lens_tensor: %s", attn_metadata.seq_lens_tensor)
        logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
        logger.debug(" block_tables: %s", attn_metadata.block_tables)

    return new_model_input
def
supports_gpu_multi_step
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
"""
if
not
allow_gpu_advance_step
:
return
False
# Append the output token to the sequence data.
# We allow multi-step GPU only in decode mode
assert
self
.
cached_seq_group_metadata_list
is
not
None
for
seq_group
in
execute_model_req
.
seq_group_metadata_list
:
for
seq_group_metadata
,
sequence_group_outputs
in
zip
(
if
seq_group
.
is_prompt
:
self
.
cached_seq_group_metadata_list
,
last_output
.
outputs
):
return
False
seq_group_metadata
.
is_prompt
=
False
for
seq_output
in
sequence_group_outputs
.
samples
:
# TODO: Add support for other attn backends
seq
=
seq_group_metadata
.
seq_data
[
seq_output
.
parent_seq_id
]
if
self
.
attn_backend
.
get_name
()
!=
"flash-attn"
:
return
False
token_id
=
seq_output
.
output_token
# TODO: Add support for LORA
token_logprob
=
seq_output
.
logprobs
[
token_id
]
if
self
.
lora_config
:
return
False
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
# TODO: Add soft-tuning prompt adapter support
seq
.
update_num_computed_tokens
(
1
)
if
self
.
prompt_adapter_config
:
return
False
return
self
.
prepare_model_input
(
self
.
cached_seq_group_metadata_list
)
return
True
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
@@ -125,42 +237,86 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -125,42 +237,86 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
)
->
Optional
[
List
[
SamplerOutput
]]:
# Since we do not broadcast data inside execute_model anymore,
"""Executes num_steps forward passes with advancement of input tensors
# we need to figure out the best way to support TP > 1 in this
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if
not
self
.
is_driver_worker
:
raise
ValueError
(
"TP1DraftModelRunner only supports TP=1."
)
if
self
.
lora_config
:
Optimizations used:
assert
model_input
.
lora_requests
is
not
None
1. Input tensors are updated on the GPU directly
assert
model_input
.
lora_mapping
is
not
None
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
self
.
set_active_loras
(
model_input
.
lora_requests
,
them since we do batch expansion later that uses GPU outputs)
model_input
.
lora_mapping
)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
if
self
.
prompt_adapter_config
:
# When num_steps == 1, we execute the fallback here for the GPU
assert
model_input
.
prompt_adapter_requests
is
not
None
# advance_step, which runs prepare_inputs on CPU and for each spec
assert
model_input
.
prompt_adapter_mapping
is
not
None
# iteration invokes this function only once
self
.
set_active_prompt_adapters
(
# (Look at multi-step-worker code)
model_input
.
prompt_adapter_requests
,
is_fallback
=
num_steps
==
1
model_input
.
prompt_adapter_mapping
)
if
not
is_fallback
:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if
not
self
.
is_driver_worker
:
raise
ValueError
(
"TP1DraftModelRunner only supports TP=1."
)
# Sanity
if
self
.
lora_config
is
not
None
:
raise
ValueError
(
"TP1DraftModelRunner has no support for LORA"
)
if
self
.
prompt_adapter_config
is
not
None
:
raise
ValueError
(
"TP1DraftModelRunner has no support for "
"prompt_adapter_config"
)
if
model_input
.
multi_modal_kwargs
:
raise
ValueError
(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else
:
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
self
.
prompt_adapter_config
:
assert
model_input
.
prompt_adapter_requests
is
not
None
assert
model_input
.
prompt_adapter_mapping
is
not
None
self
.
set_active_prompt_adapters
(
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
# Detect exec mode
assert
model_input
.
attn_metadata
is
not
None
use_cuda_graph
=
False
if
model_input
.
attn_metadata
.
num_prefills
>
0
:
# In this case, execute_model(..) was called directly
if
num_steps
>
1
:
raise
ValueError
(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill"
)
else
:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
=
(
not
is_fallback
)
# Attn attr defines if we use cuda graphs
use_cuda_graph
=
model_input
.
attn_metadata
.
use_cuda_graph
# Get model
if
use_cuda_graph
:
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
(
self
.
graph_runners
[
model_input
.
virtual_engine
]
[
graph_batch_size
])
else
:
model_executable
=
self
.
model
virtual_engine
=
model_input
.
virtual_engine
outputs
:
List
[
SamplerOutput
]
=
[]
outputs
:
List
[
SamplerOutput
]
=
[]
for
step
in
range
(
num_steps
):
for
step
in
range
(
num_steps
):
# Currently cuda graph is only supported by the decode phase.
assert
model_input
.
attn_metadata
is
not
None
prefill_meta
=
model_input
.
attn_metadata
.
prefill_metadata
decode_meta
=
model_input
.
attn_metadata
.
decode_metadata
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
assert
model_input
.
input_tokens
is
not
None
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
(
self
.
graph_runners
[
virtual_engine
][
graph_batch_size
])
else
:
model_executable
=
self
.
model
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
# Run model
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
@@ -181,8 +337,8 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -181,8 +337,8 @@ class TP1DraftModelRunner(ModelRunner):
sampling_metadata
=
model_input
.
sampling_metadata
,
sampling_metadata
=
model_input
.
sampling_metadata
,
))
))
# Prepare the inputs for the next step.
# Prepare inputs for the next step
if
step
!=
num_steps
-
1
:
if
step
!=
num_steps
-
1
:
model_input
=
self
.
update_model_input
(
model_input
,
outputs
[
-
1
])
model_input
=
self
.
_gpu_advance_step
(
model_input
,
outputs
[
-
1
])
return
outputs
return
outputs
vllm/spec_decode/interfaces.py
View file @
500b93c8
...
@@ -22,6 +22,9 @@ class SpeculativeProposals:
...
@@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero.
# The valid length of each proposal; can be zero.
proposal_lens
:
torch
.
Tensor
proposal_lens
:
torch
.
Tensor
# A flag to mark that there's no available proposals
no_proposals
:
bool
=
False
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
f
"SpeculativeProposals("
return
(
f
"SpeculativeProposals("
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
...
...
vllm/spec_decode/metrics.py
View file @
500b93c8
...
@@ -145,6 +145,10 @@ class AsyncMetricsCollector:
...
@@ -145,6 +145,10 @@ class AsyncMetricsCollector:
"""
"""
ready_event
.
synchronize
()
ready_event
.
synchronize
()
# update time of last collection
self
.
_last_metrics_collect_time
=
self
.
_timer
()
accepted_tokens
=
self
.
_aggregate_num_accepted_tokens
.
item
()
accepted_tokens
=
self
.
_aggregate_num_accepted_tokens
.
item
()
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
emitted_tokens
=
self
.
_aggregate_num_emitted_tokens
.
item
()
draft_tokens
=
self
.
_aggregate_num_draft_tokens
draft_tokens
=
self
.
_aggregate_num_draft_tokens
...
...
vllm/spec_decode/multi_step_worker.py
View file @
500b93c8
...
@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
)
)
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
# Need include_gpu_probs_tensor for multi_step_worker
# Need include_gpu_probs_tensor for MultiStepWorker
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
):
if
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request
.
num_steps
=
sample_len
expanded_request
.
num_steps
=
sample_len
model_outputs
=
self
.
execute_model
(
model_outputs
=
self
.
execute_model
(
execute_model_req
=
expanded_request
)
execute_model_req
=
expanded_request
)
else
:
else
:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for
_
in
range
(
sample_len
):
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
super
().
execute_model
(
model_output
:
List
[
SamplerOutput
]
=
super
().
execute_model
(
execute_model_req
=
expanded_request
)
execute_model_req
=
expanded_request
)
...
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
outputs
=
[
outputs
=
[
expanded_batch_output
.
outputs
[
i
]
expanded_batch_output
.
outputs
[
i
]
for
i
in
output_indices_to_retain
for
i
in
output_indices_to_retain
],
]
if
len
(
expanded_batch_output
.
outputs
)
>
0
else
[]
,
sampled_token_probs
=
(
sampled_token_probs
=
(
expanded_batch_output
.
expanded_batch_output
.
sampled_token_probs
[
output_indices_to_retain
]
sampled_token_probs
[
output_indices_to_retain
]
...
...
vllm/spec_decode/ngram_worker.py
View file @
500b93c8
...
@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
...
@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
Current NGramWorker only implements prompt lookup decoding,
and in future we may also do RAG type drafter and other scenarios
and in future we may also do RAG type drafter and other scenarios
which don't rely on LLM model to give proposals.
which don't rely on LLM model to give proposals.
"""
"""
...
@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
...
@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
load_model
=
lambda
*
args
,
**
kwargs
:
None
self
.
load_model
=
lambda
*
args
,
**
kwargs
:
None
# Current only support Top1Proposer
# Current NGramWorker only supports Top1Proposer
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
device
=
self
.
device
,
device
=
self
.
device
,
...
...
vllm/spec_decode/proposer_worker_base.py
View file @
500b93c8
...
@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
...
@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
raise
NotImplementedError
raise
NotImplementedError
def
set_include_gpu_probs_tensor
(
self
):
def
set_include_gpu_probs_tensor
(
self
)
->
None
:
"""Implementation optional"""
"""Implementation optional"""
pass
pass
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
500b93c8
...
@@ -9,12 +9,12 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
...
@@ -9,12 +9,12 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
TypicalAcceptanceSampler
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SamplerOutput
,
SequenceGroupMetadata
,
HiddenStates
,
SamplerOutput
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
)
get_all_seq_ids
,
get_all_seq_ids_and_request_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
...
@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
...
@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.smaller_tp_proposer_worker
import
SmallerTpProposerWorker
from
vllm.spec_decode.target_model_runner
import
TargetModelRunner
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
get_sampled_token_logprobs
,
nvtx_range
,
...
@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config
:
SpeculativeConfig
=
kwargs
.
get
(
"speculative_config"
)
speculative_config
:
SpeculativeConfig
=
kwargs
.
get
(
"speculative_config"
)
assert
speculative_config
is
not
None
assert
speculative_config
is
not
None
draft_worker_kwargs
=
kwargs
.
copy
()
kwargs
[
"model_runner_cls"
]
=
TargetModelRunner
target_worker
=
Worker
(
*
args
,
**
kwargs
)
target_worker
=
Worker
(
*
args
,
**
kwargs
)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker
.
model_runner
.
disable_logprobs
=
\
speculative_config
.
disable_logprobs
draft_worker_kwargs
=
kwargs
.
copy
()
# Override draft-model specific worker args.
# Override draft-model specific worker args.
draft_worker_kwargs
.
update
(
draft_worker_kwargs
.
update
(
model_config
=
speculative_config
.
draft_model_config
,
model_config
=
speculative_config
.
draft_model_config
,
...
@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold
=
speculative_config
.
typical_acceptance_sampler_posterior_threshold
=
speculative_config
.
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
)
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
)
return
spec_decode_worker
return
spec_decode_worker
...
@@ -107,8 +115,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -107,8 +115,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_token_acceptance_method
:
str
,
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
ngram_prompt_lookup_max
=
(
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
ngram_prompt_lookup_min
=
(
...
@@ -133,6 +143,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -133,6 +143,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
draft_tp
==
1
:
if
draft_tp
==
1
:
draft_worker_kwargs
[
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
allow_zero_draft_token_step
=
False
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
...
@@ -155,18 +167,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -155,18 +167,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"Configuring SpecDecodeWorker with sampler=%s"
,
logger
.
info
(
"Configuring SpecDecodeWorker with sampler=%s"
,
type
(
spec_decode_sampler
))
type
(
spec_decode_sampler
))
return
SpecDecodeWorker
(
proposer_worker
,
return
SpecDecodeWorker
(
scorer_worker
,
proposer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
scorer_worker
,
spec_decode_sampler
=
spec_decode_sampler
)
disable_logprobs
=
disable_logprobs
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
def
__init__
(
def
__init__
(
self
,
self
,
proposer_worker
:
ProposerWorkerBase
,
proposer_worker
:
ProposerWorkerBase
,
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
disable_logprobs
:
bool
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
):
):
"""
"""
Create a SpecDecodeWorker.
Create a SpecDecodeWorker.
...
@@ -183,15 +200,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -183,15 +200,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_by_batch_size: If the batch size is larger than this,
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
spec_decode_sampler
self
.
spec_decode_sampler
)
if
metrics_collector
is
None
else
metrics_collector
)
if
metrics_collector
is
None
else
metrics_collector
...
@@ -206,12 +230,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -206,12 +230,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
# Lazy initia
z
liation.
# Lazy initiali
z
ation.
self
.
scorer
:
SpeculativeScorer
self
.
scorer
:
SpeculativeScorer
# Hidden states from target model to pass to proposer
# Hidden states from target model to pass to proposer
# in the subsequent step.
# in the subsequent step.
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
_disable_logprobs
=
disable_logprobs
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""Initialize both scorer and proposer models.
...
@@ -347,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -347,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
==
0
or
disable_all_speculation
:
)
==
0
or
disable_all_speculation
:
return
self
.
_run_no_spec
(
execute_model_req
,
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all_speculation
)
skip_proposer
=
disable_all_speculation
)
return
self
.
_run_speculative_decoding_step
(
execute_model_req
,
return
self
.
_run_speculative_decoding_step
(
execute_model_req
,
num_lookahead_slots
)
num_lookahead_slots
)
...
@@ -381,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -381,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker.
# this state within spec decode worker.
seq_group_metadata
.
num_speculative_tokens
=
0
seq_group_metadata
.
num_speculative_tokens
=
0
def
_serialize_sampler_output_no_logprobs
(
self
,
execute_model_req
:
ExecuteModelRequest
,
sampler_output
:
SamplerOutput
)
->
SamplerOutput
:
"""
Creates and returns a `SamplerOutput` with only the sampled token IDs
being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
Args:
execute_model_req (ExecuteModelRequest): The model request that
was executed.
sampler_output (SamplerOutput): The output from the sampler with
only GPU tensors populated.
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only sampled token
IDs populated.
"""
seq_ids
=
get_all_seq_ids
(
execute_model_req
.
seq_group_metadata_list
)
sampled_token_ids_list
=
sampler_output
.
sampled_token_ids
.
tolist
()
completion_seq_group_output_list
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
index
,
seq_id
in
enumerate
(
seq_ids
):
completion_seq_group_output_list
.
append
(
create_sequence_group_output
(
token_id
=
sampled_token_ids_list
[
index
][
0
],
token_id_logprob_rank
=-
1
,
token_id_logprob
=
0.0
,
seq_id
=
seq_id
,
topk_token_ids
=
[],
topk_logprobs
=
[],
))
return
SamplerOutput
(
outputs
=
completion_seq_group_output_list
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
...
@@ -407,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -407,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
previous_hidden_states
.
update
(
self
.
previous_hidden_states
.
update
(
execute_model_req
.
seq_group_metadata_list
,
hidden_states
)
execute_model_req
.
seq_group_metadata_list
,
hidden_states
)
sampler_output_to_return
=
(
self
.
_serialize_sampler_output_no_logprobs
(
execute_model_req
=
execute_model_req
,
sampler_output
=
sampler_output
)
if
self
.
_disable_logprobs
else
sampler_output
)
# Clear device tensors from sampler output. This reduces communication
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
# overhead when the engine runs in a different process than the workers.
sampler_output
.
probs
=
None
sampler_output
.
sampled_token_
probs
=
None
sampler_output
.
sampled_tokens
=
None
sampler_output
.
sampled_token
_id
s
=
None
sampler_output
.
logprobs
=
None
sampler_output
.
logprobs
=
None
return
[
sampler_output
]
return
[
sampler_output
_to_return
]
def
_run_non_driver_rank
(
self
)
->
bool
:
def
_run_non_driver_rank
(
self
)
->
bool
:
"""Run proposer and verifier model in non-driver workers. This is used
"""Run proposer and verifier model in non-driver workers. This is used
...
@@ -461,11 +526,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -461,11 +526,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
if
not
self
.
_allow_zero_draft_token_step
and
proposals
.
no_proposals
:
#TODO: Fix it #5814
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
execute_model_req
,
proposals
,
proposals
,
)
)
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
proposals
,
execute_model_req
.
num_lookahead_slots
)
...
@@ -521,11 +590,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -521,11 +590,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
# Sampler arguments
sampler_extra_kwargs
=
{}
if
isinstance
(
self
.
spec_decode_sampler
,
SpecDecodeStochasticBaseSampler
):
# Get sequence group state
generators
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generators
.
append
(
seq_group_metadata
.
state
.
generator
)
else
:
generators
.
append
(
None
)
sampler_extra_kwargs
[
"generators"
]
=
generators
accepted_token_ids
=
self
.
spec_decode_sampler
(
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_probs
=
proposal_verifier_probs
,
target_probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
bonus_token_ids
=
bonus_token_ids
,
draft_probs
=
proposal_probs
,
draft_probs
=
proposal_probs
,
draft_token_ids
=
proposal_token_ids
,
draft_token_ids
=
proposal_token_ids
,
**
sampler_extra_kwargs
,
)
)
# Append output tokens from non-speculative sequences to
# Append output tokens from non-speculative sequences to
...
@@ -569,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -569,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
the same number of outputs.
the same number of outputs.
"""
"""
batch_size
,
num_steps
=
accepted_token_ids
.
shape
batch_size
,
num_steps
=
accepted_token_ids
.
shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
accepted_token_ids_by_step
=
accepted_token_ids
.
transpose
(
0
,
1
)
if
self
.
_disable_logprobs
:
# Get the logprobs/rank of the accepted tokens.
# We are skipping the logprobs. Hence don't serialize the
(
accepted_token_id_ranks_by_step
,
# logprobs related tensors from the GPU. Instead create
accepted_token_id_logprobs_by_step
)
=
get_sampled_token_logprobs
(
# empty/dummy lists.
logprob_tensor
=
target_logprobs_by_step
,
(
accepted_token_id_ranks_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
accepted_token_id_logprobs_by_step
,
)
topk_logprobs_by_step
,
topk_indices_by_step
)
=
\
self
.
_create_dummy_logprob_lists
(
# Get the top-k logprobs (which may or may not include the logprob of
batch_size
,
num_steps
,
# the accepted token).
self
.
scorer_worker
.
model_config
.
max_logprobs
)
(
topk_logprobs_by_step
,
else
:
topk_indices_by_step
)
=
target_logprobs_by_step
.
topk
(
# Organize input tensors by step instead of by sequence.
k
=
self
.
scorer_worker
.
model_config
.
max_logprobs
,
target_logprobs_by_step
=
target_logprobs
.
transpose
(
0
,
1
)
dim
=-
1
,
# Serialize all tensors into Python lists.
)
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
=
\
self
.
_create_logprob_lists_from_tensors
(
target_logprobs_by_step
,
accepted_token_ids_by_step
,
self
.
scorer_worker
.
model_config
.
max_logprobs
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
# batch.
...
@@ -596,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -596,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
num_logprobs_per_seq
=
get_all_num_logprobs
(
seq_group_metadata_list
)
# Serialize
all
tensor
s
to CPU Python list
s
.
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
accepted_token_ids_by_step
=
accepted_token_ids_by_step
.
tolist
()
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step
.
tolist
()
# Construct the output on a per-step, per-sequence basis.
# Construct the output on a per-step, per-sequence basis.
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
sampler_output_list
:
List
[
SamplerOutput
]
=
[]
...
@@ -645,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -645,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
0
].
spec_decode_worker_metrics
=
maybe_rejsample_metrics
0
].
spec_decode_worker_metrics
=
maybe_rejsample_metrics
return
sampler_output_list
return
sampler_output_list
def
_create_dummy_logprob_lists
(
self
,
batch_size
:
int
,
num_steps
:
int
,
num_top_k
:
int
,
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
float
]],
List
[
List
[
List
[
Optional
[
float
]]]],
List
[
List
[
List
[
Optional
[
int
]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step
=
[[
-
1
]
*
batch_size
for
_
in
range
(
num_steps
)]
accepted_token_id_logprobs_by_step
=
[[
0.0
]
*
batch_size
for
_
in
range
(
num_steps
)]
topk_logprobs_by_step
:
List
[
List
[
List
[
Optional
[
float
]]]]
=
[[
[
None
]
*
num_top_k
for
_
in
range
(
batch_size
)
]
for
_
in
range
(
num_steps
)]
topk_indices_by_step
:
List
[
List
[
List
[
Optional
[
int
]]]]
=
[[
[
None
]
*
num_top_k
for
_
in
range
(
batch_size
)
]
for
_
in
range
(
num_steps
)]
return
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
def
_create_logprob_lists_from_tensors
(
self
,
target_logprobs_by_step
:
torch
.
Tensor
,
accepted_token_ids_by_step
:
torch
.
Tensor
,
num_top_k
:
int
,
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
float
]],
List
[
List
[
List
[
Optional
[
float
]]]],
List
[
List
[
List
[
Optional
[
int
]]]]]:
"""
Creates and returns four lists representing token probabilities and
their ranks.
This method initializes and returns four lists containing:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
target_logprobs_by_step (torch.Tensor): Tensor representing the
log probabilities of the target model,
shaped (num_steps, batch_size, vocab_size)
accepted_token_ids_by_step (torch.Tensor): Tensor representing
the accepted token_ids, shaped (num_steps, batch_size)
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing the lists as described above.
"""
# Serialize all tensors to CPU Python lists.
# Get the logprobs/rank of the accepted tokens.
(
accepted_token_id_ranks_by_step_tensor
,
accepted_token_id_logprobs_by_step_tensor
)
=
get_sampled_token_logprobs
(
logprob_tensor
=
target_logprobs_by_step
,
sampled_token_ids
=
accepted_token_ids_by_step
,
)
# Get the top-k logprobs (which may or may not include the
# logprob of the accepted token).
(
topk_logprobs_by_step_tensor
,
topk_indices_by_step_tensor
)
=
target_logprobs_by_step
.
topk
(
k
=
num_top_k
,
dim
=-
1
,
)
accepted_token_id_ranks_by_step
=
(
accepted_token_id_ranks_by_step_tensor
.
tolist
())
accepted_token_id_logprobs_by_step
=
(
accepted_token_id_logprobs_by_step_tensor
.
tolist
())
topk_logprobs_by_step
=
topk_logprobs_by_step_tensor
.
tolist
()
topk_indices_by_step
=
topk_indices_by_step_tensor
.
tolist
()
return
(
accepted_token_id_ranks_by_step
,
accepted_token_id_logprobs_by_step
,
topk_logprobs_by_step
,
topk_indices_by_step
)
def
_track_finished_requests
(
self
,
execute_model_req
:
ExecuteModelRequest
):
def
_track_finished_requests
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""
"""
Removes the finished requests and their associated sequence ids from
Removes the finished requests and their associated sequence ids from
...
...
vllm/spec_decode/target_model_runner.py
0 → 100644
View file @
500b93c8
from
typing
import
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
ModelRunner
)
class
TargetModelRunner
(
ModelRunner
):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
that the time spent in the log probability calculation of the target model
is time wasted, since we calculate log probabilities after deciding which
tokens are accepted. For this reason disabling log probabilities in the
target model will make decode faster. The model runner sets the
SamplingMetadata parameters according to whether log probabilities are
requested or not.
"""
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
return_hidden_states
:
bool
=
False
):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
self
.
disable_logprobs
=
True
super
().
__init__
(
model_config
=
model_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
cache_config
=
cache_config
,
load_config
=
load_config
,
lora_config
=
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
multimodal_config
=
multimodal_config
,
prompt_adapter_config
=
prompt_adapter_config
,
return_hidden_states
=
return_hidden_states
,
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForGPUWithSamplingMetadata
:
model_input
:
ModelInputForGPUWithSamplingMetadata
=
super
(
).
prepare_model_input
(
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
# sampling related tensors which includes the logprobs tensors.
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
=
(
self
.
disable_logprobs
)
return
model_input
vllm/spec_decode/top1_proposer.py
View file @
500b93c8
...
@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_token_ids
=
proposal_tokens
,
proposal_token_ids
=
proposal_tokens
,
proposal_probs
=
proposal_probs
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
proposal_lens
=
proposal_lens
,
)
no_proposals
=
maybe_sampler_output
is
None
)
return
proposals
return
proposals
...
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# are supported.
# If max_proposal_len is defined, then we shall no exc
cess
this
# If max_proposal_len is defined, then we shall no exc
eed
this
# quota for nonzero_proposal
# quota for nonzero_proposal
new_k
=
0
new_k
=
0
if
(
self
.
max_proposal_len
is
None
if
(
self
.
max_proposal_len
is
None
...
@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens
:
List
[
int
],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
sampler_transposed
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
t
ensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
T
ensor
,
torch
.
Tensor
]:
"""After speculations are produced, merge the speculation results with
"""After speculations are produced, merge the speculation results with
the skipped sequences.
the skipped sequences.
"""
"""
...
...
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment