Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19bc93d9
Commit
19bc93d9
authored
Oct 24, 2024
by
王敏
Browse files
增加medusa并行解码功能,后续增加使用说明和测试文档
parent
aba40fda
Changes
45
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1537 additions
and
143 deletions
+1537
-143
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+55
-2
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+14
-0
vllm/config.py
vllm/config.py
+24
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+28
-1
vllm/lora/layers.py
vllm/lora/layers.py
+114
-63
vllm/lora/models.py
vllm/lora/models.py
+7
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+18
-3
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+132
-7
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+82
-7
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+14
-4
vllm/sequence.py
vllm/sequence.py
+70
-1
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+189
-0
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+16
-0
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+179
-10
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+216
-25
vllm/spec_decode/tree_style_proposer.py
vllm/spec_decode/tree_style_proposer.py
+318
-0
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+44
-14
vllm/transformers_utils/configs/medusa.py
vllm/transformers_utils/configs/medusa.py
+11
-1
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+5
-0
No files found.
vllm/attention/backends/xformers.py
View file @
19bc93d9
...
@@ -16,6 +16,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
...
@@ -16,6 +16,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -67,6 +68,50 @@ class XFormersBackend(AttentionBackend):
...
@@ -67,6 +68,50 @@ class XFormersBackend(AttentionBackend):
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
@
dataclass
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
@@ -144,6 +189,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -144,6 +189,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# It is a list because it is needed to set per prompt
...
@@ -223,7 +270,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -223,7 +270,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
cross_block_tables
=
self
.
cross_block_tables
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
@
property
@
property
...
@@ -262,7 +310,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -262,7 +310,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
cross_block_tables
=
self
.
cross_block_tables
,
tree_attention_masks_tensor
=
self
.
tree_attention_masks_tensor
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
...
@@ -633,6 +682,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -633,6 +682,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
block_tables_arg
,
block_tables_arg
,
)
=
_get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
)
=
_get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
tree_attention_masks_tensor
=
decode_meta
.
tree_attention_masks_tensor
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
decode_query
,
key_cache
,
key_cache
,
...
@@ -646,6 +697,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -646,6 +697,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self
.
alibi_slopes
,
self
.
alibi_slopes
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
attn_masks
=
tree_attention_masks_tensor
,
attn_masks_stride
=
tree_attention_masks_tensor
.
stride
(
0
)
if
tree_attention_masks_tensor
is
not
None
else
0
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
vllm/attention/ops/paged_attn.py
View file @
19bc93d9
...
@@ -103,6 +103,8 @@ class PagedAttention:
...
@@ -103,6 +103,8 @@ class PagedAttention:
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
blocksparse_head_sliding_step
:
int
=
0
,
attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_masks_stride
:
int
=
0
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
blocksparse_vert_stride
is
not
None
and
blocksparse_vert_stride
>
1
:
if
blocksparse_vert_stride
is
not
None
and
blocksparse_vert_stride
>
1
:
# use blocksparse paged attention
# use blocksparse paged attention
...
@@ -160,6 +162,8 @@ class PagedAttention:
...
@@ -160,6 +162,8 @@ class PagedAttention:
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
ops
.
paged_attention_v1_opt
(
ops
.
paged_attention_v1_opt
(
...
@@ -182,6 +186,8 @@ class PagedAttention:
...
@@ -182,6 +186,8 @@ class PagedAttention:
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
...
@@ -204,6 +210,8 @@ class PagedAttention:
...
@@ -204,6 +210,8 @@ class PagedAttention:
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
...
@@ -251,6 +259,8 @@ class PagedAttention:
...
@@ -251,6 +259,8 @@ class PagedAttention:
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
ops
.
paged_attention_v2_opt
(
ops
.
paged_attention_v2_opt
(
...
@@ -276,6 +286,8 @@ class PagedAttention:
...
@@ -276,6 +286,8 @@ class PagedAttention:
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
...
@@ -301,6 +313,8 @@ class PagedAttention:
...
@@ -301,6 +313,8 @@ class PagedAttention:
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
)
return
output
return
output
...
...
vllm/config.py
View file @
19bc93d9
...
@@ -1129,6 +1129,8 @@ class SpeculativeConfig:
...
@@ -1129,6 +1129,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold
:
Optional
[
float
],
typical_acceptance_sampler_posterior_threshold
:
Optional
[
float
],
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
disable_logprobs
:
Optional
[
bool
],
disable_logprobs
:
Optional
[
bool
],
num_speculative_heads
:
Optional
[
int
],
tree_style_spec_decoding
:
Optional
[
bool
]
=
None
,
)
->
Optional
[
"SpeculativeConfig"
]:
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
"""Create a SpeculativeConfig if possible, else return None.
...
@@ -1186,6 +1188,12 @@ class SpeculativeConfig:
...
@@ -1186,6 +1188,12 @@ class SpeculativeConfig:
If set to False, token log probabilities are returned
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
If not specified, it defaults to True.
num_speculative_heads (Optional[int]): It will be used in tree-style
speculative generation, representing how many heads the draft model
has.
tree_style_spec_decoding (Optional[bool]): If set to True,
tree-style generation will be activated. If not specified,
it defaults to False.
Returns:
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
@@ -1264,6 +1272,10 @@ class SpeculativeConfig:
...
@@ -1264,6 +1272,10 @@ class SpeculativeConfig:
and
hasattr
(
draft_hf_config
,
"num_lookahead_tokens"
)):
and
hasattr
(
draft_hf_config
,
"num_lookahead_tokens"
)):
draft_hf_config
.
num_lookahead_tokens
=
num_speculative_tokens
draft_hf_config
.
num_lookahead_tokens
=
num_speculative_tokens
if
(
num_speculative_heads
is
not
None
and
hasattr
(
draft_hf_config
,
"num_lookahead_heads"
)):
draft_hf_config
.
num_lookahead_heads
=
num_speculative_heads
n_predict
=
getattr
(
draft_hf_config
,
"n_predict"
,
None
)
n_predict
=
getattr
(
draft_hf_config
,
"n_predict"
,
None
)
if
n_predict
is
not
None
:
if
n_predict
is
not
None
:
if
num_speculative_tokens
is
None
:
if
num_speculative_tokens
is
None
:
...
@@ -1296,9 +1308,9 @@ class SpeculativeConfig:
...
@@ -1296,9 +1308,9 @@ class SpeculativeConfig:
"n_predict parameter."
)
"n_predict parameter."
)
if
typical_acceptance_sampler_posterior_threshold
is
None
:
if
typical_acceptance_sampler_posterior_threshold
is
None
:
typical_acceptance_sampler_posterior_threshold
=
0.
09
typical_acceptance_sampler_posterior_threshold
=
0.
3
if
typical_acceptance_sampler_posterior_alpha
is
None
:
if
typical_acceptance_sampler_posterior_alpha
is
None
:
typical_acceptance_sampler_posterior_alpha
=
0.
3
typical_acceptance_sampler_posterior_alpha
=
0.
09
if
disable_logprobs
is
None
:
if
disable_logprobs
is
None
:
disable_logprobs
=
True
disable_logprobs
=
True
...
@@ -1316,6 +1328,7 @@ class SpeculativeConfig:
...
@@ -1316,6 +1328,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
disable_logprobs
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
tree_style_spec_decoding
=
tree_style_spec_decoding
)
)
@
staticmethod
@
staticmethod
...
@@ -1410,6 +1423,7 @@ class SpeculativeConfig:
...
@@ -1410,6 +1423,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
disable_log_stats
:
bool
,
tree_style_spec_decoding
:
bool
,
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -1444,6 +1458,7 @@ class SpeculativeConfig:
...
@@ -1444,6 +1458,7 @@ class SpeculativeConfig:
returned.
returned.
disable_log_stats: Whether to disable periodic printing of stage
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
times in speculative decoding.
tree_style_spec_decoding: Whether to use tree-style generation.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
...
@@ -1459,6 +1474,7 @@ class SpeculativeConfig:
...
@@ -1459,6 +1474,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
typical_acceptance_sampler_posterior_alpha
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_log_stats
=
disable_log_stats
self
.
disable_log_stats
=
disable_log_stats
self
.
tree_style_spec_decoding
=
tree_style_spec_decoding
self
.
_verify_args
()
self
.
_verify_args
()
...
@@ -1526,6 +1542,8 @@ class LoRAConfig:
...
@@ -1526,6 +1542,8 @@ class LoRAConfig:
# This is a constant.
# This is a constant.
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
merge_lora
:
bool
=
False
lora_target_modules
:
Optional
[
List
[
str
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Setting the maximum rank to 256 should be able to satisfy the vast
# Setting the maximum rank to 256 should be able to satisfy the vast
...
@@ -1548,6 +1566,10 @@ class LoRAConfig:
...
@@ -1548,6 +1566,10 @@ class LoRAConfig:
raise
ValueError
(
raise
ValueError
(
f
"max_cpu_loras (
{
self
.
max_cpu_loras
}
) must be >= "
f
"max_cpu_loras (
{
self
.
max_cpu_loras
}
) must be >= "
f
"max_loras (
{
self
.
max_loras
}
)"
)
f
"max_loras (
{
self
.
max_loras
}
)"
)
if
self
.
merge_lora
and
self
.
max_loras
>
1
:
raise
ValueError
(
f
"merge_lora (
{
self
.
merge_lora
}
) can only be used when "
f
"max_loras (
{
self
.
max_loras
}
) is 1"
)
def
verify_with_model_config
(
self
,
model_config
:
ModelConfig
):
def
verify_with_model_config
(
self
,
model_config
:
ModelConfig
):
if
self
.
lora_dtype
in
(
None
,
"auto"
):
if
self
.
lora_dtype
in
(
None
,
"auto"
):
...
...
vllm/engine/arg_utils.py
View file @
19bc93d9
...
@@ -143,6 +143,8 @@ class EngineArgs:
...
@@ -143,6 +143,8 @@ class EngineArgs:
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
long_lora_scaling_factors
:
Optional
[
Tuple
[
float
]]
=
None
lora_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
'auto'
lora_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
'auto'
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
merge_lora
:
bool
=
False
lora_target_modules
:
Optional
[
List
[
str
]]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
num_scheduler_steps
:
int
=
1
num_scheduler_steps
:
int
=
1
multi_step_stream_outputs
:
bool
=
False
multi_step_stream_outputs
:
bool
=
False
...
@@ -162,6 +164,7 @@ class EngineArgs:
...
@@ -162,6 +164,7 @@ class EngineArgs:
speculative_model_quantization
:
Optional
[
str
]
=
None
speculative_model_quantization
:
Optional
[
str
]
=
None
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
=
None
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_heads
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
...
@@ -173,6 +176,7 @@ class EngineArgs:
...
@@ -173,6 +176,7 @@ class EngineArgs:
disable_logprobs_during_spec_decoding
:
Optional
[
bool
]
=
None
disable_logprobs_during_spec_decoding
:
Optional
[
bool
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
tree_style_spec_decoding
:
Optional
[
bool
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
...
@@ -534,6 +538,14 @@ class EngineArgs:
...
@@ -534,6 +538,14 @@ class EngineArgs:
type
=
int
,
type
=
int
,
default
=
EngineArgs
.
max_lora_rank
,
default
=
EngineArgs
.
max_lora_rank
,
help
=
'Max LoRA rank.'
)
help
=
'Max LoRA rank.'
)
parser
.
add_argument
(
'--merge-lora'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, the weights of the base layer will be merged with the weights of Lora.'
)
parser
.
add_argument
(
'--lora-target-modules'
,
nargs
=
'*'
,
default
=
None
,
help
=
'List of lora module name, If not specified, modules will be chosen according to the model architecture.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--lora-extra-vocab-size'
,
'--lora-extra-vocab-size'
,
type
=
int
,
type
=
int
,
...
@@ -639,6 +651,12 @@ class EngineArgs:
...
@@ -639,6 +651,12 @@ class EngineArgs:
default
=
EngineArgs
.
num_speculative_tokens
,
default
=
EngineArgs
.
num_speculative_tokens
,
help
=
'The number of speculative tokens to sample from '
help
=
'The number of speculative tokens to sample from '
'the draft model in speculative decoding.'
)
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
'--num-speculative-heads'
,
type
=
int
,
default
=
EngineArgs
.
num_speculative_heads
,
help
=
'The number of speculative heads to sample from '
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--speculative-draft-tensor-parallel-size'
,
'--speculative-draft-tensor-parallel-size'
,
'-spec-draft-tp'
,
'-spec-draft-tp'
,
...
@@ -690,6 +708,11 @@ class EngineArgs:
...
@@ -690,6 +708,11 @@ class EngineArgs:
'a higher acceptance rate at the cost of lower quality, '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.'
)
'and vice versa.'
)
parser
.
add_argument
(
'--tree-style-spec-decoding'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, tree-style generation will be activated.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--typical-acceptance-sampler-posterior-threshold'
,
'--typical-acceptance-sampler-posterior-threshold'
,
type
=
float
,
type
=
float
,
...
@@ -974,6 +997,8 @@ class EngineArgs:
...
@@ -974,6 +997,8 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha
=
self
.
typical_acceptance_sampler_posterior_alpha
=
self
.
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
tree_style_spec_decoding
=
self
.
tree_style_spec_decoding
,
num_speculative_heads
=
self
.
num_speculative_heads
)
)
if
self
.
num_scheduler_steps
>
1
:
if
self
.
num_scheduler_steps
>
1
:
...
@@ -1016,7 +1041,9 @@ class EngineArgs:
...
@@ -1016,7 +1041,9 @@ class EngineArgs:
long_lora_scaling_factors
=
self
.
long_lora_scaling_factors
,
long_lora_scaling_factors
=
self
.
long_lora_scaling_factors
,
lora_dtype
=
self
.
lora_dtype
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
and
self
.
max_cpu_loras
>
0
else
None
,
merge_lora
=
self
.
merge_lora
,
lora_target_modules
=
self
.
lora_target_modules
)
if
self
.
enable_lora
else
None
if
self
.
qlora_adapter_name_or_path
is
not
None
and
\
if
self
.
qlora_adapter_name_or_path
is
not
None
and
\
self
.
qlora_adapter_name_or_path
!=
""
:
self
.
qlora_adapter_name_or_path
!=
""
:
...
...
vllm/lora/layers.py
View file @
19bc93d9
...
@@ -131,6 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -131,6 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
base_layer
=
base_layer
self
.
base_layer
=
base_layer
self
.
embeddings_slice
:
Optional
[
Tuple
[
int
,
int
]]
self
.
embeddings_slice
:
Optional
[
Tuple
[
int
,
int
]]
self
.
embeddings_weights
:
Optional
[
torch
.
Tensor
]
self
.
embeddings_weights
:
Optional
[
torch
.
Tensor
]
self
.
device
=
_get_lora_device
(
self
.
base_layer
)
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
...
@@ -207,6 +208,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -207,6 +208,18 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
lora_b_stacked
[
index
,
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
lora_b
.
T
,
non_blocking
=
True
)
self
.
lora_a
=
lora_a
.
to
(
self
.
device
)
self
.
lora_b
=
lora_b
.
to
(
self
.
device
)
if
self
.
lora_config
.
merge_lora
:
merged_weights
=
torch
.
matmul
(
self
.
lora_a
,
self
.
lora_b
)
if
merged_weights
.
shape
!=
self
.
base_layer
.
weight
.
data
:
merged_weights
=
merged_weights
.
T
+
self
.
base_layer
.
weight
else
:
merged_weights
=
merged_weights
+
self
.
base_layer
.
weight
self
.
base_layer
.
weight
.
data
.
copy_
(
merged_weights
)
if
embeddings_tensor
is
not
None
:
if
embeddings_tensor
is
not
None
:
self
.
embeddings_tensors
[
self
.
embeddings_tensors
[
index
,
:
embeddings_tensor
.
shape
[
0
],
:
embeddings_tensor
.
index
,
:
embeddings_tensor
.
shape
[
0
],
:
embeddings_tensor
.
...
@@ -225,12 +238,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -225,12 +238,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
added_tokens_mask
=
x
>
self
.
base_layer
.
org_vocab_size
-
1
added_tokens_mask
=
x
>
self
.
base_layer
.
org_vocab_size
-
1
embeddings_indices
=
self
.
punica_wrapper
.
embeddings_indices
embeddings_indices
=
self
.
punica_wrapper
.
embeddings_indices
indices
=
embeddings_indices
[
1
].
view_as
(
x
)
indices
=
embeddings_indices
[
0
].
view_as
(
x
)
if
not
self
.
lora_config
.
merge_lora
:
indices_0
=
embeddings_indices
[
1
].
view_as
(
x
)
full_lora_a_embeddings
=
F
.
embedding
(
full_lora_a_embeddings
=
F
.
embedding
(
x
+
indices
,
x
+
indices
_0
,
self
.
lora_a_stacked_2d
,
self
.
lora_a_stacked_2d
,
)
)
indices
=
embeddings_indices
[
0
].
view_as
(
x
)
full_output
=
self
.
base_layer
.
forward
(
full_output
=
self
.
base_layer
.
forward
(
x
.
add_
(
indices
*
added_tokens_mask
))
x
.
add_
(
indices
*
added_tokens_mask
))
...
@@ -251,6 +267,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -251,6 +267,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
lora_b_stacked
,
self
.
lora_b_stacked
,
add_input
=
True
)
add_input
=
True
)
return
full_output
.
view_as
(
full_output_org
)
return
full_output
.
view_as
(
full_output_org
)
else
:
full_output
=
self
.
base_layer
.
forward
(
x
.
add_
(
indices
*
added_tokens_mask
))
return
full_output
@
classmethod
@
classmethod
def
can_replace_layer
(
def
can_replace_layer
(
...
@@ -778,6 +798,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -778,6 +798,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
indices_len
:
List
[
int
]
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
if
not
self
.
lora_config
.
merge_lora
:
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
1
][
index
]
=
0
self
.
lora_a_stacked
[
1
][
index
]
=
0
...
@@ -822,6 +843,33 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -822,6 +843,33 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
self
.
lora_config
.
merge_lora
:
qkv_weights_list
=
[]
for
i
in
range
(
len
(
self
.
output_slices
)):
if
lora_a
[
i
]
is
not
None
:
if
lora_a
[
i
].
numel
()
==
0
or
lora_b
[
i
].
numel
()
==
0
:
continue
weight_A
=
lora_a
[
i
].
to
(
self
.
device
)
weight_B
=
lora_b
[
i
].
to
(
self
.
device
)
delta_weight
=
torch
.
matmul
(
weight_A
,
weight_B
)
qkv_weights_list
.
append
(
delta_weight
)
else
:
if
i
==
0
:
qkv_weights_list
.
append
(
torch
.
zeros
(
self
.
input_size
,
self
.
q_proj_shard_size
,
dtype
=
self
.
base_layer
.
weight
.
dtype
,
device
=
self
.
device
))
else
:
qkv_weights_list
.
append
(
torch
.
zeros
(
self
.
input_size
,
self
.
kv_proj_shard_size
,
dtype
=
self
.
base_layer
.
weight
.
dtype
,
device
=
self
.
device
))
if
len
(
qkv_weights_list
)
>
0
:
qkv_weights
=
torch
.
cat
(
qkv_weights_list
,
dim
=-
1
)
if
qkv_weights
.
shape
!=
self
.
base_layer
.
weight
.
shape
:
qkv_weights
=
qkv_weights
.
T
+
self
.
base_layer
.
weight
.
data
else
:
qkv_weights
=
qkv_weights
+
self
.
base_layer
.
weight
.
data
self
.
base_layer
.
weight
.
data
.
copy_
(
qkv_weights
)
else
:
if
lora_b
[
0
]
is
not
None
:
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
]
lora_b_q
=
lora_b
[
0
]
self
.
lora_b_stacked
[
0
][
self
.
lora_b_stacked
[
0
][
...
@@ -853,11 +901,14 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -853,11 +901,14 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def
apply
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
if
not
self
.
lora_config
.
merge_lora
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
punica_wrapper
.
add_lora_packed_nslice
(
output
,
x
,
self
.
lora_a_stacked
,
self
.
lora_a_stacked
,
self
.
lora_b_stacked
,
1.0
,
self
.
lora_b_stacked
,
1.0
,
self
.
output_slices
)
self
.
output_slices
)
else
:
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
return
output
return
output
@
classmethod
@
classmethod
...
...
vllm/lora/models.py
View file @
19bc93d9
...
@@ -117,6 +117,8 @@ class LoRAModel(AdapterModel):
...
@@ -117,6 +117,8 @@ class LoRAModel(AdapterModel):
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
for
tensor_name
,
tensor
in
tensors
.
items
():
for
tensor_name
,
tensor
in
tensors
.
items
():
if
"lora_A"
not
in
tensor_name
and
"lora_B"
not
in
tensor_name
:
continue
module_name
,
is_lora_a
=
parse_fine_tuned_lora_name
(
tensor_name
)
module_name
,
is_lora_a
=
parse_fine_tuned_lora_name
(
tensor_name
)
if
module_name
not
in
loras
:
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
lora_embeddings_tensor
=
None
...
@@ -324,6 +326,9 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -324,6 +326,9 @@ class LoRAModelManager(AdapterModelManager):
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
super
().
__init__
(
model
)
super
().
__init__
(
model
)
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
if
lora_config
.
lora_target_modules
is
not
None
:
self
.
supported_lora_modules
=
lora_config
.
lora_target_modules
else
:
self
.
supported_lora_modules
=
copy
.
deepcopy
(
self
.
supported_lora_modules
=
copy
.
deepcopy
(
self
.
model
.
supported_lora_modules
)
self
.
model
.
supported_lora_modules
)
if
lora_config
.
long_lora_scaling_factors
:
if
lora_config
.
long_lora_scaling_factors
:
...
...
vllm/model_executor/layers/sampler.py
View file @
19bc93d9
...
@@ -117,6 +117,15 @@ class SamplerOutput(
...
@@ -117,6 +117,15 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time.
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time
:
Optional
[
float
]
=
None
model_execute_time
:
Optional
[
float
]
=
None
# Optional lm_head logits from the model.
logits
:
Optional
[
torch
.
Tensor
]
=
None
# tree-style cartesian candidates
cart_candidates
:
Optional
[
torch
.
Tensor
]
=
None
# tree-style cartesian candidates
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
def
__getitem__
(
self
,
idx
:
int
):
return
self
.
outputs
[
idx
]
return
self
.
outputs
[
idx
]
...
@@ -141,7 +150,9 @@ class SamplerOutput(
...
@@ -141,7 +150,9 @@ class SamplerOutput(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
, "
f
"logits=
{
self
.
logits
}
, "
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
)"
)
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
...
@@ -224,6 +235,7 @@ class Sampler(nn.Module):
...
@@ -224,6 +235,7 @@ class Sampler(nn.Module):
sampling_metadata: Metadata for sampling.
sampling_metadata: Metadata for sampling.
"""
"""
assert
logits
is
not
None
assert
logits
is
not
None
original_logits
=
logits
.
clone
()
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
# Prepare sampling tensors with pinned memory to avoid blocking.
# Prepare sampling tensors with pinned memory to avoid blocking.
...
@@ -307,7 +319,8 @@ class Sampler(nn.Module):
...
@@ -307,7 +319,8 @@ class Sampler(nn.Module):
prompt_logprobs
,
prompt_logprobs
,
sample_logprobs
,
sample_logprobs
,
on_device_tensors
=
on_device_tensors
,
on_device_tensors
=
on_device_tensors
,
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
)
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
,
logits
=
original_logits
)
@
property
@
property
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
...
@@ -1237,6 +1250,7 @@ def _build_sampler_output(
...
@@ -1237,6 +1250,7 @@ def _build_sampler_output(
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
torch
.
Tensor
]],
skip_sampler_cpu_output
:
bool
=
False
,
skip_sampler_cpu_output
:
bool
=
False
,
logits
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplerOutput
:
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
"""Construct Python objects with the output of sampling.
...
@@ -1287,7 +1301,8 @@ def _build_sampler_output(
...
@@ -1287,7 +1301,8 @@ def _build_sampler_output(
sampled_token_probs
=
sampled_token_probs
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
logprobs
=
logprobs_tensor
,
deferred_sample_results_args
=
deferred_sample_results_args
)
deferred_sample_results_args
=
deferred_sample_results_args
,
logits
=
logits
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
19bc93d9
from
typing
import
Optional
,
List
import
torch
import
torch
import
torch.jit
import
torch.jit
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeDeterministicBaseSampler
)
SpecDecodeDeterministicBaseSampler
)
...
@@ -40,6 +42,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -40,6 +42,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
cart_candidates
:
Optional
[
torch
.
Tensor
]
=
None
,
best_candidates
:
Optional
[
List
]
=
None
,
accept_lengths
:
Optional
[
List
]
=
None
,
first_step_flags
:
Optional
[
List
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample token ids using typical acceptance sampling. This accepts
"""Sample token ids using typical acceptance sampling. This accepts
or rejects tokens proposed by the draft model using the probability
or rejects tokens proposed by the draft model using the probability
...
@@ -66,6 +72,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -66,6 +72,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
probabilities.
probabilities.
shape = [batch_size, num_speculative_tokens]
shape = [batch_size, num_speculative_tokens]
cart_candidates: tree-style cartesian candidates
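best_candidates: output list; the index of the best candidate of each sequence is appended to it
accept_lengths: output list; the accepted length of each sequence is appended to it
first_step_flags: per-sequence flags telling whether this is the first decoding step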
Returns:
Returns:
output_token_ids: The token ids sampled via rejection sampling,
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
or -1 if unable to sample a token because the previous token
...
@@ -77,6 +88,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -77,6 +88,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
if
self
.
_strict_mode
:
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
draft_token_ids
,
bonus_token_ids
)
draft_token_ids
,
bonus_token_ids
)
if
cart_candidates
is
None
:
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
draft_token_ids
)
...
@@ -84,8 +97,120 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -84,8 +97,120 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
)
bonus_token_ids
)
else
:
target_probs
=
target_with_bonus_probs
output_token_ids
=
self
.
_evaluate_accepted_tokens_tree_style
(
target_probs
,
draft_token_ids
,
cart_candidates
,
best_candidates
,
accept_lengths
,
first_step_flags
)
return
output_token_ids
return
output_token_ids
def _evaluate_accepted_tokens_tree_style(self, target_probs, draft_token_ids,
                                         cart_candidates,
                                         output_best_candidates,
                                         accept_lengths, first_step_flags):
    r"""
    Pick the best tree candidate per sequence with typical acceptance
    and return its accepted token ids, padded with -1.

    Parameters:
    ----------
    target_probs : torch.Tensor
        Target-model probabilities for every candidate path. It is
        indexed as [batch, candidate, position, vocab]; the last
        position is dropped before evaluation.
        NOTE(review): presumably
        (batch_size, retrieve_size, tree_depth, vocab_size) -- confirm
        against the caller.
    draft_token_ids : torch.Tensor
        Only its last dimension ``k`` is used: it fixes the width of the
        returned tensor.
    cart_candidates : torch.Tensor
        A tensor of shape (batch_size, retrieve_size, tree_depth)
        holding the cartesian candidates of the tree proposals. Column 0
        is never scored; columns 1: are the tokens that get evaluated.
    output_best_candidates : list
        Output argument. The index of the chosen candidate of every
        sequence is appended (as a 0-dim tensor).
    accept_lengths : list
        Output argument. The accepted length of every sequence is
        appended (as a 0-dim tensor).
    first_step_flags : sequence of bool
        Per-sequence flag. When set, column 0 of the chosen candidate is
        left out of the returned tokens.

    A candidate token x_{n+k} is accepted if it satisfies the
    following condition

    .. math::
        p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
        \min \left( \epsilon, \delta * \exp \left(
        -H(p_{\text{original}}(
        \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

    where :math:`p_{\text{original}}` corresponds to target_probs
    and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
    specified using self._posterior_threshold and self._posterior_alpha

    For every candidate path the accepted length is the number of
    leading tokens that pass the test. Among the paths that reach the
    longest accepted length, the one with the highest summed
    log-probability over the accepted prefix is selected. If nothing is
    accepted anywhere in the batch, candidate 0 is used.

    Returns:
    -------
    torch.Tensor
        An integer tensor of shape (batch_size, k) holding the selected
        token ids of each sequence followed by -1 padding.
    """
    # Drop the last position: it has no candidate token to be checked.
    target_probs = target_probs[:, :, :-1]
    device = target_probs.device
    batch_size = cart_candidates.shape[0]

    # Probability the target model assigns to each candidate token.
    candidates_prob = torch.gather(
        target_probs, dim=-1,
        index=cart_candidates[:, :, 1:].unsqueeze(-1)
    ).squeeze(-1)  # [batch_size, retrieve_size, max_depth]

    # 1e-5 keeps log() finite for zero probabilities.
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + 1e-5), dim=-1
    )  # [batch_size, retrieve_size, max_depth]

    threshold = torch.minimum(
        torch.ones_like(posterior_entropy) * self._posterior_threshold,
        torch.exp(-posterior_entropy) * self._posterior_alpha,
    )
    posterior_mask = candidates_prob > threshold  # [batch_size, retrieve_size, max_depth]
    # cumprod zeroes everything after the first rejected token, so the
    # sum counts the accepted prefix only.
    candidates_accept_length = (torch.cumprod(posterior_mask, dim=2)).sum(
        dim=-1)  # [batch_size, retrieve_size]

    # Choose the best candidate based on the evaluated posterior probabilities
    accept_length, _ = candidates_accept_length.max(dim=-1)  # [batch_size]
    if torch.any(accept_length > 0):
        # Keep only the candidates that reach the longest accepted length.
        valid_index = (candidates_accept_length ==
                       accept_length.unsqueeze(-1)).unsqueeze(
                           -1)  # [batch_size, retrieve_size, 1]
        candidates_prob = candidates_prob * valid_index  # [batch_size, retrieve_size, max_depth]
        # Keep only the positions inside the accepted prefix.
        valid_index = torch.arange(
            candidates_prob.shape[-1],
            device=device).unsqueeze(0).unsqueeze(0).repeat(
                batch_size, candidates_prob.shape[1],
                1)  # [batch_size, retrieve_size, max_depth]
        valid_index = (valid_index < accept_length.unsqueeze(1).unsqueeze(
            2).repeat(1, candidates_prob.shape[1], 1))
        candidates_prob = candidates_prob * valid_index  # [batch_size, retrieve_size, max_depth]
        # add 1e-3 to avoid zero value
        # (sum of logs is used instead of a product of probabilities)
        likelihood = torch.sum(torch.log(candidates_prob + 1e-3),
                               dim=-1)  # [batch_size, retrieve_size]
        best_candidate = torch.argmax(likelihood, dim=-1)  # [batch_size]
    else:
        # Nothing accepted anywhere: fall back to the first candidate.
        best_candidate = torch.zeros((batch_size),
                                     dtype=torch.long,
                                     device=device)  # [batch_size]

    k = draft_token_ids.shape[-1]
    output_token_id_list = []
    for i in range(batch_size):
        output_best_candidates.append(best_candidate[i])
        accept_lengths.append(accept_length[i])
        # Plain ints for slicing/padding; the output lists still receive
        # the 0-dim tensors as before.
        num_accepted = int(accept_length[i])
        if not first_step_flags[i]:
            select_indices = cart_candidates[i, best_candidate[i],
                                             :num_accepted + 1]
            select_indices = F.pad(select_indices,
                                   (0, k - 1 - num_accepted), 'constant',
                                   -1)
        else:
            select_indices = cart_candidates[i, best_candidate[i],
                                             1:num_accepted + 1]
            select_indices = F.pad(select_indices, (0, k - num_accepted),
                                   'constant', -1)
        output_token_id_list.append(select_indices)

    return torch.stack(output_token_id_list, dim=0)
def
_evaluate_accepted_tokens
(
self
,
target_probs
,
draft_token_ids
):
def
_evaluate_accepted_tokens
(
self
,
target_probs
,
draft_token_ids
):
r
"""
r
"""
Evaluates and returns a mask of accepted tokens based on the
Evaluates and returns a mask of accepted tokens based on the
...
...
vllm/model_executor/model_loader/utils.py
View file @
19bc93d9
...
@@ -23,7 +23,7 @@ def get_model_architecture(
...
@@ -23,7 +23,7 @@ def get_model_architecture(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visual
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
visual
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
]
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2ForCausalLM'
,
'ChatGLMModel'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
architectures
==
[
'QWenLMHeadModel'
]
and
visual
!=
[]:
if
architectures
==
[
'QWenLMHeadModel'
]
and
visual
!=
[]:
...
...
vllm/model_executor/models/medusa.py
View file @
19bc93d9
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
os
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Any
,
Dict
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.transformers_utils.configs.medusa
import
MedusaConfig
from
vllm.transformers_utils.configs.medusa
import
MedusaConfig
from
vllm
import
_custom_ops
as
ops
TOPK
=
10
# topk for sparse tree (10 is a placeholder and it is sufficient)
class
ResidualBlock
(
nn
.
Module
):
class
ResidualBlock
(
nn
.
Module
):
...
@@ -46,6 +50,8 @@ class Medusa(nn.Module):
...
@@ -46,6 +50,8 @@ class Medusa(nn.Module):
def
__init__
(
self
,
config
:
MedusaConfig
,
**
_
)
->
None
:
def
__init__
(
self
,
config
:
MedusaConfig
,
**
_
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
config
=
config
self
.
config
=
config
self
.
blocks
=
nn
.
ModuleList
([
self
.
blocks
=
nn
.
ModuleList
([
ResidualBlock
(
hidden_size
=
self
.
config
.
hidden_size
,
ResidualBlock
(
hidden_size
=
self
.
config
.
hidden_size
,
...
@@ -55,6 +61,7 @@ class Medusa(nn.Module):
...
@@ -55,6 +61,7 @@ class Medusa(nn.Module):
self
.
orig_vocab_size
=
config
.
vocab_size
self
.
orig_vocab_size
=
config
.
vocab_size
self
.
truncated_vocab_size
=
config
.
truncated_vocab_size
self
.
truncated_vocab_size
=
config
.
truncated_vocab_size
self
.
unpadded_vocab_size
=
self
.
truncated_vocab_size
self
.
unpadded_vocab_size
=
self
.
truncated_vocab_size
self
.
medusa_choices
=
config
.
medusa_choices
self
.
lm_heads
=
nn
.
ModuleList
([
self
.
lm_heads
=
nn
.
ModuleList
([
ParallelLMHead
(
ParallelLMHead
(
...
@@ -138,11 +145,62 @@ class Medusa(nn.Module):
...
@@ -138,11 +145,62 @@ class Medusa(nn.Module):
return
outputs
return
outputs
def medusa_sample(
    self,
    medusa_logits: List[torch.Tensor],
    sampling_metadata: SamplingMetadata,
    logits: torch.Tensor,
    medusa_buffers: Dict[str, Any],
) -> List[SamplerOutput]:
    """Build tree-style draft candidates from the base and Medusa logits.

    The greedy token of ``logits`` is concatenated with the TOPK token ids
    of every Medusa head. ``medusa_buffers['tree_indices']`` selects the
    tree candidates from that flat list and
    ``medusa_buffers['retrieve_indices']`` expands them into cartesian
    candidate paths.

    Args:
        medusa_logits: one logits tensor per Medusa head; they are stacked
            along a new leading dimension.
        sampling_metadata: its ``seq_groups`` provide the
            ``sample_indices`` rows that belong to each sequence group.
        logits: logits whose argmax gives the first candidate token of
            each row.
        medusa_buffers: must contain ``'tree_indices'`` and
            ``'retrieve_indices'``.

    Returns:
        One SamplerOutput per sequence group, carrying the tree candidates
        as ``sampled_token_ids`` and the cartesian candidates as
        ``cart_candidates``.
    """
    num_rows = logits.shape[0]

    # Greedy token of the base logits.  [batch_size, 1]
    greedy_tokens = logits.argmax(dim=-1).view(num_rows, -1)

    # TOPK token ids of every head, flattened per row.
    # [medusa_heads, batch_size, vocab_size] -> [batch_size, medusa_heads*TOPK]
    stacked_heads = torch.stack(medusa_logits, dim=0)
    head_topk = torch.topk(stacked_heads, TOPK, dim=-1).indices
    head_topk = head_topk.permute(1, 0, 2).reshape(num_rows, -1)

    # [batch_size, 1+medusa_heads*TOPK]
    flat_candidates = torch.cat([greedy_tokens, head_topk], dim=-1)

    # Map the flat candidates onto the tree layout.  [batch_size, choices]
    tree_candidates = torch.index_select(
        flat_candidates, dim=-1, index=medusa_buffers['tree_indices'])

    # A trailing zero column is appended before the retrieve lookup.
    zero_column = torch.zeros((num_rows, 1),
                              dtype=torch.long,
                              device=tree_candidates.device)
    tree_candidates_ext = torch.cat([tree_candidates, zero_column], dim=-1)

    # [batch_size, retrieve_size, max_depth]
    cart_candidates = tree_candidates_ext[:,
                                          medusa_buffers['retrieve_indices']]

    outputs: List[Optional[SamplerOutput]] = [
        SamplerOutput(
            outputs=None,
            sampled_token_ids=tree_candidates[group.sample_indices, :]
            .squeeze(1),
            cart_candidates=cart_candidates[group.sample_indices, :],
        ) for group in sampling_metadata.seq_groups
    ]
    return outputs
def
generate_proposals
(
def
generate_proposals
(
self
,
self
,
previous_hidden_states
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
previous_logits
:
torch
.
Tensor
=
None
,
medusa_buffers
:
Dict
[
str
,
Any
]
=
None
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
if
medusa_buffers
is
None
:
return
self
.
sample
(
return
self
.
sample
(
logits
=
self
.
compute_logits
(
logits
=
self
.
compute_logits
(
hidden_states
=
self
.
forward
(
previous_hidden_states
),
hidden_states
=
self
.
forward
(
previous_hidden_states
),
...
@@ -150,6 +208,14 @@ class Medusa(nn.Module):
...
@@ -150,6 +208,14 @@ class Medusa(nn.Module):
),
),
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
else
:
return
self
.
medusa_sample
(
medusa_logits
=
self
.
compute_logits
(
hidden_states
=
self
.
forward
(
previous_hidden_states
),
sampling_metadata
=
sampling_metadata
,
),
sampling_metadata
=
sampling_metadata
,
logits
=
previous_logits
,
medusa_buffers
=
medusa_buffers
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
...
@@ -177,6 +243,15 @@ class Medusa(nn.Module):
...
@@ -177,6 +243,15 @@ class Medusa(nn.Module):
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
self
.
use_llama_nn
and
"lm_head"
in
name
:
_weight
=
torch
.
zeros_like
(
param
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
param
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
param
.
data
.
copy_
(
_weight
)
param
.
data
=
param
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
token_map
is
not
None
:
if
self
.
token_map
is
not
None
:
self
.
token_map
.
to
(
device
=
self
.
lm_heads
[
0
].
weight
.
device
)
self
.
token_map
.
to
(
device
=
self
.
lm_heads
[
0
].
weight
.
device
)
...
...
vllm/model_executor/models/qwen2.py
View file @
19bc93d9
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -329,6 +329,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -329,6 +329,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
"o_proj"
,
"o_proj"
,
"gate_up_proj"
,
"gate_up_proj"
,
"down_proj"
,
"down_proj"
,
"lm_head"
]
]
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
...
@@ -363,9 +364,18 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -363,9 +364,18 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
# self.lm_head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/sequence.py
View file @
19bc93d9
...
@@ -988,6 +988,9 @@ class SequenceGroupMetadata(
...
@@ -988,6 +988,9 @@ class SequenceGroupMetadata(
# TODO: We should maintain this states out of the sequence group.
# TODO: We should maintain this states out of the sequence group.
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
seq_data
is
not
None
and
self
.
token_chunk_size
is
None
:
if
self
.
seq_data
is
not
None
and
self
.
token_chunk_size
is
None
:
if
self
.
is_prompt
:
if
self
.
is_prompt
:
...
@@ -1025,6 +1028,11 @@ class SequenceGroupMetadata(
...
@@ -1025,6 +1028,11 @@ class SequenceGroupMetadata(
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
self
.
state
.
current_step
+=
1
self
.
state
.
current_step
+=
1
def set_tree_style_args(self, tree_attn_masks: Optional[torch.Tensor],
                        tree_position_ids: Optional[torch.Tensor]):
    """Store the tree-style attention masks and position ids on this object.

    Both values are assigned as given; passing None clears the field.
    """
    self.tree_position_ids = tree_position_ids
    self.tree_attn_masks = tree_attn_masks
class
SequenceOutput
(
class
SequenceOutput
(
msgspec
.
Struct
,
msgspec
.
Struct
,
...
@@ -1271,6 +1279,56 @@ class HiddenStates(msgspec.Struct, array_like=True,
...
@@ -1271,6 +1279,56 @@ class HiddenStates(msgspec.Struct, array_like=True,
[
self
.
hidden_states
,
self
.
second_last_token_hidden_states
])[
index
]
[
self
.
hidden_states
,
self
.
second_last_token_hidden_states
])[
index
]
class Logits(
        msgspec.Struct, array_like=True,
        omit_defaults=True):  # type: ignore[call-arg]
    """Logits corresponding to in-progress sequences.
    Used in speculative decoding to pass lm_head logits from
    the target model to the proposer model in the subsequent step.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the logits tensor"""
    # lm_head logits. The first dimension is the batch dimension; its
    # length is checked against the sequence group list below.
    logits: torch.Tensor
    # Optional sequence group metadata list. When given, it must have one
    # entry per row of `logits`; the sequence ids are derived from it.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Sequence id of each row of `logits`, in row order.
    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        # Derive the row -> sequence id mapping from the metadata, if any.
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.logits)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        """Sequence ids matching the rows of `logits`."""
        return self._seq_ids

    def update(self, logits: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Append logits (and their sequence ids) from a new target model
        invocation to the ones already stored."""
        assert len(seq_group_metadata_list) == len(logits)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.logits = torch.cat([self.logits, logits])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the rows of the sequences in seq_group_metadata_list,
        reordered to match that list.

        Raises ValueError if a requested sequence id is not stored.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.logits = self.logits[index]
            self._seq_ids = seq_ids
class
ExecuteModelRequest
(
class
ExecuteModelRequest
(
msgspec
.
Struct
,
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
...
@@ -1296,6 +1354,8 @@ class ExecuteModelRequest(
...
@@ -1296,6 +1354,8 @@ class ExecuteModelRequest(
running_queue_size
:
int
=
0
running_queue_size
:
int
=
0
# Optional hidden states from prior step.
# Optional hidden states from prior step.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
# Optional logits from prior step.
previous_logits
:
Optional
[
Logits
]
=
None
# The number of forward steps to run.
# The number of forward steps to run.
num_steps
:
int
=
1
num_steps
:
int
=
1
# Finished request ids since last step.
# Finished request ids since last step.
...
@@ -1305,6 +1365,12 @@ class ExecuteModelRequest(
...
@@ -1305,6 +1365,12 @@ class ExecuteModelRequest(
# Async callback
# Async callback
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
# Optional tree attention mask from draft model.
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
# Optional tree position ids from draft model.
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
# TODO(will) make this be able to handle batches with variable number of
# TODO(will) make this be able to handle batches with variable number of
...
@@ -1346,8 +1412,11 @@ class ExecuteModelRequest(
...
@@ -1346,8 +1412,11 @@ class ExecuteModelRequest(
num_lookahead_slots
=
self
.
num_lookahead_slots
,
num_lookahead_slots
=
self
.
num_lookahead_slots
,
running_queue_size
=
self
.
running_queue_size
,
running_queue_size
=
self
.
running_queue_size
,
previous_hidden_states
=
self
.
previous_hidden_states
,
previous_hidden_states
=
self
.
previous_hidden_states
,
previous_logits
=
self
.
previous_logits
,
num_steps
=
self
.
num_steps
,
num_steps
=
self
.
num_steps
,
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
)
async_callback
=
self
.
async_callback
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_position_ids
=
self
.
tree_position_ids
)
vllm/spec_decode/batch_expansion.py
View file @
19bc93d9
...
@@ -112,6 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -112,6 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids
=
all_tokens
,
token_ids
=
all_tokens
,
logprobs
=
spec_logprobs
,
logprobs
=
spec_logprobs
,
hidden_states
=
all_hidden_states
,
hidden_states
=
all_hidden_states
,
logits
=
target_sampler_output
.
logits
,
)
)
def
_expand_batch
(
def
_expand_batch
(
...
@@ -460,3 +461,191 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -460,3 +461,191 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score
.
extend
(
full_spec_token_ids
[:
i
+
1
]
token_ids_to_score
.
extend
(
full_spec_token_ids
[:
i
+
1
]
for
i
in
range
(
len
(
full_spec_token_ids
)))
for
i
in
range
(
len
(
full_spec_token_ids
)))
return
token_ids_to_score
return
token_ids_to_score
class
BatchExpansionTreeStyleScorer
(
BatchExpansionTop1Scorer
):
def __init__(self, scorer_worker: WorkerBase, device: str,
             vocab_size: int):
    """Create the tree-style scorer.

    All three arguments are forwarded unchanged to the parent class
    initializer; no extra state is set up here.
    """
    super().__init__(scorer_worker, device, vocab_size)
def _contract_batch(
        self, contracted_bs: int, target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals, num_scoring_tokens: int,
        non_spec_indices: List[int], spec_indices: List[int],
        k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                         Optional[torch.Tensor]]:
    """Contract the expanded batch back into its original size.
    This maps the scores of speculative tokens back to their original
    sequences.

    contracted_bs is the original batch size, and the batch size that the
    target_sampler_output will be contracted to.

    Note: the ``k`` argument is not used; the per-sequence width is taken
    from the second dimension of ``proposals.proposal_token_ids``.

    Returns a tuple (token ids, probs, logprobs, hidden states or None),
    each with ``contracted_bs`` rows. Rows not written keep their fill
    values: -1 for token ids, 0 for probs and hidden states, -inf for
    logprobs.
    """
    split_output = self._split_scoring_output(target_sampler_output,
                                              num_scoring_tokens)
    (spec_token_ids, spec_probs, spec_logprobs, spec_hidden_states,
     non_spec_token_ids, non_spec_probs, non_spec_logprobs,
     non_spec_hidden_states) = split_output

    # Width of the contracted tensors comes from the proposals.
    expanded_batch_size, width = proposals.proposal_token_ids.shape

    # The number of tokens in the expanded batch used for speculation is
    # equal to the total expanded batch size minus the number of samples
    # for non-speculative sequences.
    spec_expanded_bs = expanded_batch_size - len(non_spec_token_ids)

    # Fold the flat speculative scores into [spec rows, width, ...].
    spec_token_ids = spec_token_ids.reshape(spec_expanded_bs, width)
    spec_probs = spec_probs.reshape(*spec_token_ids.shape,
                                    self._vocab_size)
    spec_logprobs = spec_logprobs.reshape(spec_probs.shape)
    if spec_hidden_states is not None:
        spec_hidden_states = spec_hidden_states.reshape(
            *spec_token_ids.shape, spec_hidden_states.shape[-1])

    # Allocate the contracted outputs with their fill values.
    all_tokens = spec_token_ids.new_full(size=(contracted_bs, width),
                                         fill_value=-1)
    all_probs = spec_probs.new_zeros(*all_tokens.shape, self._vocab_size)
    all_logprobs = spec_logprobs.new_full(size=all_probs.shape,
                                          fill_value=-float("inf"))
    all_hidden_states = None
    if target_sampler_output.hidden_states is not None:
        all_hidden_states = spec_hidden_states.new_zeros(
            size=(contracted_bs, width, spec_hidden_states.shape[-1]))

    # Non-speculative sequences only fill the first position.
    if non_spec_indices:
        all_tokens[non_spec_indices, :1] = non_spec_token_ids.unsqueeze(1)
        all_probs[non_spec_indices, :1, :] = non_spec_probs.unsqueeze(1)
        all_logprobs[non_spec_indices, :1, :] = \
            non_spec_logprobs.unsqueeze(1)
        if all_hidden_states is not None:
            assert non_spec_hidden_states is not None
            all_hidden_states[non_spec_indices, :1, :] = \
                non_spec_hidden_states.unsqueeze(1)

    # Speculative sequences fill their whole row.
    if spec_indices:
        all_tokens[spec_indices] = spec_token_ids
        all_probs[spec_indices] = spec_probs
        all_logprobs[spec_indices] = spec_logprobs
        if all_hidden_states is not None:
            all_hidden_states[spec_indices] = spec_hidden_states

    return all_tokens, all_probs, all_logprobs, all_hidden_states
def
_contract_batch_all_spec
(
self
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs
,
k
=
proposals
.
proposal_token_ids
.
shape
# Reshape tensors to original batch size
target_token_ids
=
target_sampler_output
.
sampled_token_ids
.
reshape
(
contracted_bs
,
k
)
target_probs
=
target_sampler_output
.
sampled_token_probs
.
reshape
(
*
target_token_ids
.
shape
,
self
.
_vocab_size
)
target_logprobs
=
target_sampler_output
.
logprobs
.
reshape
(
target_probs
.
shape
)
target_hidden_states
=
target_sampler_output
.
hidden_states
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
@staticmethod
def _create_single_target_seq_group_metadata(
    seq_group_metadata: SequenceGroupMetadata,
    seq_id: SeqId,
    target_seq_id: TargetSeqId,
    token_ids: List[TokenId],
    sampling_params: SamplingParams,
) -> SequenceGroupMetadata:
    """Build the metadata for one target (scoring) sequence.

    Args:
        seq_group_metadata: The metadata for the input sequence.
        seq_id: The input sequence ID.
        target_seq_id: The corresponding target sequence ID.
        token_ids: The list of token ids that are to be appended to the
            input sequence.
        sampling_params: Sampling parameters attached to the returned
            metadata.

    Returns:
        A SequenceGroupMetadata holding a single sequence keyed by
        ``target_seq_id`` that reuses the block table of ``seq_id``.
    """
    source_data = seq_group_metadata.seq_data[seq_id]
    prompt_token_ids = source_data.prompt_token_ids_array

    # On the first step the only output token is the one produced by the
    # prefill phase; it is dropped before the new tokens are appended.
    previous_output_ids = source_data.get_output_token_ids()
    if len(previous_output_ids) == 1:
        previous_output_ids = previous_output_ids[:-1]
    combined_output_ids = [*previous_output_ids, *token_ids]

    target_data = SequenceData(
        prompt_token_ids,
        _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                combined_output_ids),
    )

    # This is a hack. Technically, spec decoding should compute
    # num_lookahead slots at one shot, but instead, it expands the batch
    # and evaluates one by one right now. context_len is seq_len - 1
    # because the kv cache is filled by a previous batch in the batch
    # expansion.
    target_data.update_num_computed_tokens(target_data.get_len() - 1)

    source_block_table = seq_group_metadata.block_tables[seq_id]
    return SequenceGroupMetadata(
        request_id=seq_group_metadata.request_id,
        is_prompt=seq_group_metadata.is_prompt,
        seq_data={target_seq_id: target_data},
        sampling_params=sampling_params,
        block_tables={target_seq_id: source_block_table},
        lora_request=None,
        token_chunk_size=1,
    )
def _get_token_ids_to_score(
    self,
    full_spec_token_ids: List[TokenId]  # shape: [k]
) -> List[List[TokenId]]:
    """Return every non-empty prefix of the proposal token ids.

    For k proposal tokens this yields k lists, ordered from the shortest
    prefix (first token only) to the full proposal. No empty prefix is
    produced; an empty input gives an empty result.

    Example:
        Input: [0, 1, 2, 3] (k=4)

        Output: (k lists)
            [0]
            [0, 1]
            [0, 1, 2]
            [0, 1, 2, 3]
    """
    prefixes = []
    for prefix_len in range(1, len(full_spec_token_ids) + 1):
        prefixes.append(full_spec_token_ids[:prefix_len])
    return prefixes
vllm/spec_decode/interfaces.py
View file @
19bc93d9
...
@@ -25,6 +25,19 @@ class SpeculativeProposals:
...
@@ -25,6 +25,19 @@ class SpeculativeProposals:
# A flag to mark that there's no available proposals
# A flag to mark that there's no available proposals
no_proposals
:
bool
=
False
no_proposals
:
bool
=
False
# The cart_candidates used in tree-style generation
cart_candidates
:
Optional
[
torch
.
Tensor
]
=
None
# The retrieve_indices used in tree-style generation
retrieve_indices
:
Optional
[
torch
.
Tensor
]
=
None
# tree-style attention masks
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
# tree-style position ids
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
f
"SpeculativeProposals("
return
(
f
"SpeculativeProposals("
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
...
@@ -53,6 +66,9 @@ class SpeculativeScores:
...
@@ -53,6 +66,9 @@ class SpeculativeScores:
# Optional last hidden states from the scoring model.
# Optional last hidden states from the scoring model.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Optional lm_head logits from the scoring model.
logits
:
Optional
[
torch
.
Tensor
]
=
None
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
f
"SpeculativeScores("
return
(
f
"SpeculativeScores("
f
"probs=
{
self
.
probs
.
shape
}
, "
f
"probs=
{
self
.
probs
.
shape
}
, "
...
...
vllm/spec_decode/medusa_worker.py
View file @
19bc93d9
...
@@ -2,35 +2,63 @@ import weakref
...
@@ -2,35 +2,63 @@ import weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
,
SpeculativeProposer
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.tree_style_proposer
import
TreeStyleProposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
TOPK
=
10
# topk for sparse tree (10 is a placeholder and it is sufficient)
class
MedusaWorker
(
NonLLMProposerWorkerBase
,
Worker
):
class
MedusaWorker
(
NonLLMProposerWorkerBase
,
Worker
):
"""Worker for Medusa.
"""Worker for Medusa.
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
# skip lora config in medusa
kwargs_copy
=
kwargs
.
copy
()
kwargs_copy
[
'lora_config'
]
=
None
super
().
__init__
(
*
args
,
**
kwargs_copy
)
# Lazy initialization list.
# Lazy initialization list.
self
.
_proposer
:
Top1
Proposer
self
.
_proposer
:
Speculative
Proposer
def
init_device
(
self
):
def
init_device
(
self
):
super
().
init_device
()
super
().
init_device
()
def
load_model
(
self
):
super
().
load_model
()
# get medusa choices and generate medusa_buffers
self
.
medusa_buffers
=
None
if
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
self
.
medusa_choices
,
device
=
self
.
device
)
if
self
.
medusa_buffers
is
None
:
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
device
,
self
.
vocab_size
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
max_proposal_len
=
self
.
max_model_len
,
)
)
else
:
self
.
_proposer
=
TreeStyleProposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
vocab_size
,
self
.
medusa_buffers
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
):
def
set_include_gpu_probs_tensor
(
self
):
pass
pass
...
@@ -69,7 +97,24 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -69,7 +97,24 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
hidden_states
,
hidden_states
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
,
previous_logits
=
execute_model_req
.
previous_logits
.
logits
if
execute_model_req
.
previous_logits
is
not
None
else
None
,
medusa_buffers
=
self
.
medusa_buffers
)
# create tree attn masks
if
self
.
medusa_buffers
is
not
None
:
max_context_len
=
max
(
seq_lens
)
for
sampler_output
,
seq_len
in
zip
(
model_outputs
,
seq_lens
):
context_len
=
seq_len
attn_masks
=
self
.
medusa_buffers
[
'tree_attn_masks'
]
left_mask
=
torch
.
ones
(
attn_masks
.
shape
[
0
],
context_len
,
dtype
=
attn_masks
.
dtype
,
device
=
attn_masks
.
device
)
attn_masks
=
torch
.
cat
([
left_mask
,
attn_masks
],
dim
=-
1
)
right_pad
=
max_context_len
-
context_len
if
right_pad
>
0
:
attn_masks
=
F
.
pad
(
attn_masks
,
(
0
,
right_pad
),
"constant"
,
0
)
sampler_output
.
tree_attn_masks
=
attn_masks
return
model_outputs
,
False
return
model_outputs
,
False
...
@@ -96,6 +141,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -96,6 +141,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
query_lens
.
append
(
seq_len
-
context_len
)
query_lens
.
append
(
seq_len
-
context_len
)
else
:
else
:
# The first step of tree decoding needs to ignore the first token
if
self
.
medusa_buffers
is
not
None
and
seq_data
.
get_output_len
()
==
1
:
seq_data_len
-=
1
seq_lens
.
append
(
seq_data_len
)
seq_lens
.
append
(
seq_data_len
)
query_lens
.
append
(
1
)
query_lens
.
append
(
1
)
...
@@ -134,3 +182,124 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -134,3 +182,124 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
execute_model_req
.
seq_group_metadata_list
):
execute_model_req
.
seq_group_metadata_list
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"MedusaWorker does not support beam search."
)
"MedusaWorker does not support beam search."
)
def pad_path(self, path, length, pad_value=-2):
    """
    Right-pad a path with a filler value until it reaches a target length.

    Parameters:
    - path (list): The list to pad. It is not modified.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value appended as padding.

    Returns:
    - list: A new list holding the items of path followed by as many
      pad_value entries as are needed to reach length.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    When path already has length or more items, nothing is appended and
    the result is an unpadded copy of path.
    """
    missing = length - len(path)
    # A non-positive repeat count produces an empty list, so over-long
    # paths come back without padding.
    filler = [pad_value] * missing
    return path + filler
def generate_medusa_buffers(self, medusa_choices, device="cuda"):
    """
    Generate buffers for the Medusa tree structure based on the provided choices.

    Node 0 of the tree is the root; node ``i + 1`` is the i-th entry of
    ``medusa_choices`` after sorting by (depth, value).

    Parameters:
    - medusa_choices (list): A nested list representing the tree. Each entry
      is a path of integer choices; every proper prefix of a path must also
      be present.
    - device (str): Device to which the tensors should be moved. Default is "cuda".

    Returns:
    - dict: A dictionary of tensors on ``device``:
        - "tree_attn_masks": int tensor [medusa_len, medusa_len]; row r has
          ones on the diagonal, in column 0 and in the columns of the
          ancestors of node r.
        - "tree_indices": long tensor [medusa_len]; for a node at depth d
          whose path ends in c the value is ``c + TOPK * (d - 1) + 1``
          (0 for the root).
        - "tree_position_ids": long tensor [medusa_len]; the depth of each
          node (0 for the root).
        - "retrieve_indices": long tensor [num_leaf_paths, max_depth + 1];
          each row lists the node indices from the root down to one leaf,
          padded on the right with -1.

    Raises:
    - ValueError: If ``medusa_choices`` is empty or a path refers to a
      prefix that is not itself listed in ``medusa_choices``.
    """
    # Sort the medusa_choices based on their lengths and then their values
    sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
    if not sorted_medusa_choices:
        raise ValueError("medusa_choices must contain at least one path.")
    medusa_len = len(sorted_medusa_choices) + 1

    # Position of the first occurrence of every path, so prefix lookups are
    # O(1) instead of a linear list.index() scan per ancestor.
    node_index = {}
    for idx, path in enumerate(sorted_medusa_choices):
        node_index.setdefault(tuple(path), idx)

    def index_of(prefix):
        key = tuple(prefix)
        if key not in node_index:
            raise ValueError(f"{list(key)} is not in medusa_choices")
        return node_index[key]

    # depth_counts[d - 1] is the number of choices of depth d
    depth_counts = []
    prev_depth = 0
    for path in sorted_medusa_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth

    # Every node attends to itself (diagonal) and to the root (column 0)
    medusa_attn_mask = torch.eye(medusa_len, medusa_len)
    medusa_attn_mask[:, 0] = 1
    medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
    medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)

    # Walk the sorted choices one depth level at a time
    start = 0
    for depth_idx, count in enumerate(depth_counts):
        medusa_position_ids[start + 1:start + count + 1] = depth_idx + 1
        for row in range(start, start + count):
            cur_medusa_choice = sorted_medusa_choices[row]
            medusa_tree_indices[row + 1] = (cur_medusa_choice[-1] +
                                            TOPK * depth_idx + 1)
            if len(cur_medusa_choice) == 1:
                # Depth-1 nodes have no ancestor other than the root
                continue
            # +1 because node 0 is the root
            ancestor_idx = [
                index_of(cur_medusa_choice[:c + 1]) + 1
                for c in range(len(cur_medusa_choice) - 1)
            ]
            medusa_attn_mask[row + 1, ancestor_idx] = 1
        start += count

    # Generate retrieval indices for Medusa structure verification.
    # Visiting the deepest paths first means a path already recorded as a
    # prefix of an earlier one is not a leaf and gets no row of its own.
    retrieve_indices_nest = []
    covered_paths = set()
    for cur_medusa_choice in reversed(sorted_medusa_choices):
        if tuple(cur_medusa_choice) in covered_paths:
            continue
        retrieve_indice = []
        for c in range(len(cur_medusa_choice)):
            prefix = cur_medusa_choice[:c + 1]
            retrieve_indice.append(index_of(prefix))
            covered_paths.add(tuple(prefix))
        retrieve_indices_nest.append(retrieve_indice)

    max_length = max(len(x) for x in retrieve_indices_nest)
    retrieve_indices = [
        self.pad_path(path, max_length) for path in retrieve_indices_nest
    ]
    # Shift by one for the root node; the -2 padding becomes -1
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + 1
    root_column = torch.zeros((retrieve_indices.shape[0], 1),
                              dtype=torch.long)
    retrieve_indices = torch.cat([root_column, retrieve_indices], dim=1)

    # Aggregate the generated buffers into a dictionary
    medusa_buffers = {
        "tree_attn_masks": medusa_attn_mask.int(),
        "tree_indices": medusa_tree_indices,
        "tree_position_ids": medusa_position_ids,
        "retrieve_indices": retrieve_indices,
    }

    # Move the tensors in the dictionary to the specified device
    return {
        name: buffer.to(device)
        for name, buffer in medusa_buffers.items()
    }
vllm/spec_decode/spec_decode_worker.py
View file @
19bc93d9
...
@@ -16,8 +16,8 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
...
@@ -16,8 +16,8 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SequenceGroupMetadata
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
)
get_all_seq_ids_and_request_ids
,
Logits
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
,
BatchExpansionTreeStyleScorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
...
@@ -36,6 +36,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
...
@@ -36,6 +36,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.attention.ops.paged_attn
import
PagedAttention
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -80,6 +82,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -80,6 +82,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
tree_style_spec_decoding
=
speculative_config
.
tree_style_spec_decoding
,
)
)
return
spec_decode_worker
return
spec_decode_worker
...
@@ -122,6 +125,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -122,6 +125,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
disable_log_stats
:
bool
,
tree_style_spec_decoding
:
bool
,
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
allow_zero_draft_token_step
=
True
...
@@ -183,7 +187,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -183,7 +187,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
tree_style_spec_decoding
=
tree_style_spec_decoding
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -195,6 +200,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -195,6 +200,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
tree_style_spec_decoding
:
bool
=
False
,
):
):
"""
"""
Create a SpecDecodeWorker.
Create a SpecDecodeWorker.
...
@@ -223,6 +229,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -223,6 +229,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
draft model is larger than 1 (TODO: #5814)
tree_style_spec_decoding: Whether to use tree-style generation.
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
...
@@ -247,14 +254,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -247,14 +254,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
probs_dtype
=
self
.
spec_decode_sampler
.
probs_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
self
.
token_id_dtype
=
self
.
spec_decode_sampler
.
token_id_dtype
# Lazy initialization.
# Lazy initialization.
self
.
scorer
:
Speculative
Scorer
self
.
scorer
:
BatchExpansionTop1
Scorer
# Hidden states from target model to pass to proposer
# Hidden states from target model to pass to proposer
# in the subsequent step.
# in the subsequent step.
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
self
.
previous_logits
:
Optional
[
Logits
]
=
None
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
self
.
_disable_log_stats
=
disable_log_stats
self
.
tree_style_spec_decoding
=
tree_style_spec_decoding
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""Initialize both scorer and proposer models.
"""
"""
...
@@ -270,10 +280,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -270,10 +280,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
if
not
self
.
tree_style_spec_decoding
:
self
.
scorer
=
BatchExpansionTop1Scorer
(
self
.
scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
=
self
.
scorer_worker
,
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
vocab_size
=
self
.
_vocab_size
)
else
:
self
.
scorer
=
BatchExpansionTreeStyleScorer
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
vocab_size
=
self
.
_vocab_size
)
self
.
_configure_model_sampler_for_spec_decode
()
self
.
_configure_model_sampler_for_spec_decode
()
...
@@ -532,6 +548,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -532,6 +548,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
previous_hidden_states
.
update
(
self
.
previous_hidden_states
.
update
(
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
# Store logits from target model execution.
logits
=
sampler_output
.
logits
if
logits
is
not
None
:
if
self
.
previous_logits
is
None
:
self
.
previous_logits
=
Logits
(
logits
,
execute_model_req
.
seq_group_metadata_list
)
else
:
self
.
previous_logits
.
update
(
logits
,
execute_model_req
.
seq_group_metadata_list
)
if
not
skip_proposer
:
if
not
skip_proposer
:
# We prepare the prefill hidden states here so that there no
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# additional complexity in worker for spec_decode vs non_spec_decode
...
@@ -605,6 +631,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -605,6 +631,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req
.
previous_hidden_states
=
self
.
previous_hidden_states
execute_model_req
.
previous_hidden_states
=
self
.
previous_hidden_states
self
.
previous_hidden_states
=
None
self
.
previous_hidden_states
=
None
# Pass last logits from target model to proposer
execute_model_req
.
previous_logits
=
self
.
previous_logits
self
.
previous_logits
=
None
with
Timer
()
as
proposal_timer
:
with
Timer
()
as
proposal_timer
:
# Generate proposals using draft worker.
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
...
@@ -615,6 +645,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -615,6 +645,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
"workers generate no tokens"
)
# Pass tree attention masks and position ids to target model
if
self
.
tree_style_spec_decoding
:
execute_model_req
.
tree_attn_masks
=
proposals
.
tree_attn_masks
execute_model_req
.
tree_position_ids
=
proposals
.
tree_position_ids
execute_model_req
.
previous_hidden_states
=
None
execute_model_req
.
previous_hidden_states
=
None
with
Timer
()
as
scoring_timer
:
with
Timer
()
as
scoring_timer
:
...
@@ -624,10 +659,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -624,10 +659,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
)
with
Timer
()
as
verification_timer
:
with
Timer
()
as
verification_timer
:
accepted_token_ids
,
target_logprobs
=
self
.
_verify_tokens
(
accepted_token_ids
,
target_logprobs
,
select_indices_list
,
accept_lengths
=
self
.
_verify_tokens
(
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
execute_model_req
.
seq_group_metadata_list
,
proposal_scores
,
proposals
,
execute_model_req
.
num_lookahead_slots
)
proposals
,
execute_model_req
.
num_lookahead_slots
)
# move kv_caches of selected tokens to right positions
if
self
.
tree_style_spec_decoding
:
self
.
move_caches
(
execute_model_req
,
select_indices_list
,
accept_lengths
)
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
scoring_timer
.
elapsed_time_ms
,
scoring_timer
.
elapsed_time_ms
,
verification_timer
.
elapsed_time_ms
)
verification_timer
.
elapsed_time_ms
)
...
@@ -646,7 +685,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -646,7 +685,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores
:
SpeculativeScores
,
proposal_scores
:
SpeculativeScores
,
proposals
:
SpeculativeProposals
,
proposals
:
SpeculativeProposals
,
max_proposal_len
:
int
,
max_proposal_len
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
List
[
torch
.
Tensor
],
List
[
int
]
]:
"""Determine which speculative tokens are accepted using the
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
probabilities of each token according to the proposer and scorer models.
...
@@ -666,6 +705,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -666,6 +705,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get probabilities of target model, including bonus tokens.
# Get probabilities of target model, including bonus tokens.
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
if
self
.
tree_style_spec_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
proposal_verifier_probs
=
proposal_verifier_probs
[:,
retrieve_indices
]
# Get non-speculative sampled tokens from target model.
# Get non-speculative sampled tokens from target model.
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
...
@@ -673,11 +716,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -673,11 +716,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
bonus_token_ids
=
proposal_scores
.
token_ids
[
spec_indices
,
-
1
:]
# Get probabilities according to proposal method.
# Get probabilities according to proposal method.
proposal_probs
=
proposals
.
proposal_probs
[
spec_indices
]
proposal_probs
=
proposals
.
proposal_probs
[
spec_indices
]
\
if
proposals
.
proposal_probs
is
not
None
else
None
# Get proposed tokens.
# Get proposed tokens.
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
proposal_token_ids
=
proposals
.
proposal_token_ids
[
spec_indices
]
# Get tree buffers.
cart_candidates
=
proposals
.
cart_candidates
[
spec_indices
]
if
proposals
.
cart_candidates
is
not
None
else
None
# Sampler arguments
# Sampler arguments
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
sampler_extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
if
self
.
generators
and
isinstance
(
self
.
spec_decode_sampler
,
...
@@ -688,6 +735,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -688,6 +735,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
sgm
.
sampling_params
.
seed
is
not
None
if
sgm
.
sampling_params
.
seed
is
not
None
}
}
if
isinstance
(
self
.
spec_decode_sampler
,
TypicalAcceptanceSampler
):
sampler_extra_kwargs
[
"cart_candidates"
]
=
cart_candidates
sampler_extra_kwargs
[
"best_candidates"
]
=
[]
sampler_extra_kwargs
[
"accept_lengths"
]
=
[]
first_step_flags
=
[]
for
i
,
sgm
in
enumerate
(
seq_group_metadata_list
):
seq
=
next
(
iter
(
sgm
.
seq_data
.
values
()))
first_step_flags
.
append
(
True
if
seq
.
get_output_len
()
==
1
else
False
)
sampler_extra_kwargs
[
"first_step_flags"
]
=
first_step_flags
accepted_token_ids
=
self
.
spec_decode_sampler
(
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_with_bonus_probs
=
proposal_verifier_probs
,
target_with_bonus_probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
bonus_token_ids
=
bonus_token_ids
,
...
@@ -697,8 +756,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -697,8 +756,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
)
# Append output tokens from non-speculative sequences to
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
# the accepted token ids tensor.
if
not
self
.
tree_style_spec_decoding
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
1
).
clone
()
1
).
clone
()
else
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
).
clone
()
non_spec_token_ids
[:,
1
:]
=
-
1
non_spec_token_ids
[:,
1
:]
=
-
1
accepted_token_ids
=
torch
.
cat
(
accepted_token_ids
=
torch
.
cat
(
[
accepted_token_ids
,
non_spec_token_ids
])
[
accepted_token_ids
,
non_spec_token_ids
])
...
@@ -708,6 +771,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -708,6 +771,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
accepted_token_ids
[
original_indices
]
=
accepted_token_ids
.
clone
()
hidden_states
=
proposal_scores
.
hidden_states
hidden_states
=
proposal_scores
.
hidden_states
select_indices
=
None
accept_lengths
=
None
select_indices_list
=
[]
if
cart_candidates
is
None
:
if
hidden_states
is
not
None
:
if
hidden_states
is
not
None
:
# Contract hidden states based on accepted tokens
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
hs_size
=
hidden_states
.
shape
[
-
1
]
...
@@ -721,8 +791,129 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -721,8 +791,129 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
previous_hidden_states
=
HiddenStates
(
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_metadata_list
,
hidden_states
,
seq_group_metadata_list
,
second_last_token_hidden_states
)
second_last_token_hidden_states
)
else
:
retrieve_indices
=
proposals
.
retrieve_indices
batch_size
=
len
(
seq_group_metadata_list
)
best_candidates
=
sampler_extra_kwargs
[
"best_candidates"
]
accept_lengths
=
sampler_extra_kwargs
[
"accept_lengths"
]
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
batch_size
,
-
1
,
hs_size
)
# Store logits from target model for subsequent proposal
logits
=
proposal_scores
.
logits
logits
=
logits
.
view
(
batch_size
,
-
1
,
logits
.
shape
[
-
1
])
logits
=
logits
[:,
retrieve_indices
]
# [batch_size, retrieve_size, max_depth, vocab_size]
previous_logits_list
=
[]
previous_hidden_state_list
=
[]
for
i
in
range
(
batch_size
):
logit
=
logits
[
i
,
best_candidates
[
i
],
accept_lengths
[
i
]].
unsqueeze
(
0
)
previous_logits_list
.
append
(
logit
)
select_indices
=
retrieve_indices
[
best_candidates
[
i
],
:
accept_lengths
[
i
]
+
1
]
hidden_state
=
hidden_states
[
i
,
select_indices
[
-
1
]].
unsqueeze
(
0
)
select_indices_list
.
append
(
select_indices
)
previous_hidden_state_list
.
append
(
hidden_state
)
logits
=
torch
.
cat
(
previous_logits_list
,
dim
=
0
)
self
.
previous_logits
=
Logits
(
logits
,
seq_group_metadata_list
)
hidden_states
=
torch
.
cat
(
previous_hidden_state_list
,
dim
=
0
)
# [batch_size, 1, vocab_size]
self
.
previous_hidden_states
=
HiddenStates
(
hidden_states
,
seq_group_metadata_list
,)
return
accepted_token_ids
,
logprobs
,
select_indices_list
,
accept_lengths
def
move_caches
(
self
,
execute_model_req
:
ExecuteModelRequest
,
select_indices_list
:
List
[
torch
.
Tensor
],
accept_lengths
:
List
[
int
]):
"""Given selected output tokens and accept length,
move kv_caches of selected tokens to right positions.
"""
seq_lens
=
[]
for
sg
in
execute_model_req
.
seq_group_metadata_list
:
seq_ids
=
list
(
sg
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
sg
.
seq_data
[
seq_id
]
seq_len
=
seq_data
.
get_len
()
token_chunk_size
=
sg
.
token_chunk_size
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# first step of tree-style decoding need to ignore first generated token
if
seq_data
.
get_output_len
()
==
1
:
seq_len
-=
1
seq_lens
.
append
(
seq_len
)
model_input
=
self
.
scorer
.
_scorer_worker
.
model_input
block_tables
=
None
if
hasattr
(
model_input
,
'attn_metadata'
)
and
hasattr
(
model_input
.
attn_metadata
,
'block_tables'
):
block_tables
=
model_input
.
attn_metadata
.
block_tables
if
block_tables
is
None
:
raise
RuntimeError
(
"Can not get block_tables from model_input."
)
block_tables
=
block_tables
.
cpu
().
tolist
()
cache_engine
=
self
.
scorer
.
_scorer_worker
.
cache_engines
[
execute_model_req
.
virtual_engine
]
block_size
=
cache_engine
.
block_size
batch_size
=
len
(
select_indices_list
)
block_table_stride
=
len
(
block_tables
)
//
batch_size
select_indices_slot_mapping
=
[]
target_slot_mapping
=
[]
for
i
in
range
(
batch_size
):
accept_legth
=
accept_lengths
[
i
]
if
accept_legth
>
0
:
select_indices
=
select_indices_list
[
i
][
1
:]
+
seq_lens
[
i
]
self
.
compute_slot_mapping
(
select_indices_slot_mapping
,
i
*
block_table_stride
,
select_indices
,
block_size
,
block_tables
)
target_indices
=
torch
.
arange
(
accept_legth
+
1
)[
1
:]
+
seq_lens
[
i
]
self
.
compute_slot_mapping
(
target_slot_mapping
,
i
*
block_table_stride
,
target_indices
,
block_size
,
block_tables
)
if
len
(
select_indices_slot_mapping
)
>
0
:
select_indices_slot_tensor
=
torch
.
tensor
(
select_indices_slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
).
view
(
-
1
,
1
)
target_slot_mapping_tensor
=
torch
.
tensor
(
target_slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
).
view
(
-
1
,
1
)
src_dst_tensor
=
torch
.
cat
([
select_indices_slot_tensor
,
target_slot_mapping_tensor
],
dim
=-
1
)
#[batch_size*T, 2]
kv_caches
=
self
.
scorer
.
_scorer_worker
.
kv_cache
[
execute_model_req
.
virtual_engine
]
kv_cache_dtype
=
cache_engine
.
cache_config
.
cache_dtype
backend
=
cache_engine
.
attn_backend
num_kv_heads
=
cache_engine
.
num_kv_heads
head_size
=
cache_engine
.
head_size
backend
.
move_cache
(
kv_caches
,
src_dst_tensor
,
kv_cache_dtype
,
num_kv_heads
,
head_size
)
def compute_slot_mapping(self, slot_mapping: List[int], seq_id: int,
                         select_indices: List[int], block_size: int,
                         block_tables: List[List[int]]):
    """Append the cache slot of every selected index to ``slot_mapping``.

    The row ``block_tables[seq_id]`` is used as the block table. For each
    index, the block number is looked up at ``index // block_size`` and the
    slot is ``block_number * block_size + index % block_size``.

    Args:
        slot_mapping: Output list, extended in place.
        seq_id: Row of ``block_tables`` to use.
        select_indices: Token indices to translate.
        block_size: Number of slots per block.
        block_tables: One block-number list per row.
    """
    row = block_tables[seq_id]
    slot_mapping.extend(
        row[position // block_size] * block_size + position % block_size
        for position in select_indices)
return
accepted_token_ids
,
logprobs
def
_create_output_sampler_list
(
def
_create_output_sampler_list
(
self
,
self
,
...
...
vllm/spec_decode/tree_style_proposer.py
0 → 100644
View file @
19bc93d9
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Any
,
Dict
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
ExecuteModelRequest
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
sampler_output_to_torch
class TreeStyleProposer(SpeculativeProposer):
    """Proposer that asks a draft worker for tree-style speculative tokens.

    The proposal length is the first dimension of
    ``tree_buffers["tree_indices"]``. Sequence groups whose speculation is
    disabled (``num_speculative_tokens == 0``), or whose length plus the
    proposal length is not below ``max_proposal_len``, are not sent to the
    worker; they receive placeholder entries when the outputs are merged.

    Only proposal lengths of 0 and the global proposal length are produced.
    """

    def __init__(
        self,
        worker: ProposerWorkerBase,
        device: str,
        vocab_size: int,
        tree_buffers: Dict[str, Any],
        max_proposal_len: Optional[int] = None,
    ):
        self._worker = worker
        self._device = device
        self._vocab_size = vocab_size
        self.tree_buffers = tree_buffers
        self.max_proposal_len = max_proposal_len

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Build speculative proposals for the given batch.

        Sequences rejected by ``_split_by_proposal_len`` are skipped during
        speculation and filled with placeholders afterwards.
        """
        tree_size = self.tree_buffers["tree_indices"].shape[0]
        all_groups = execute_model_req.seq_group_metadata_list

        # Separate the groups that will be speculated on from the rest.
        lens, spec_groups, spec_indices = self._split_by_proposal_len(
            all_groups, tree_size)

        # Stays None when no sequence can be speculated.
        draft_output = None
        is_transposed = False
        if spec_groups:
            # Restrict the previous-step tensors to the speculated groups.
            prev_hidden = execute_model_req.previous_hidden_states
            if prev_hidden is not None:
                prev_hidden.prune(spec_groups)
            prev_logits = execute_model_req.previous_logits
            if prev_logits is not None:
                prev_logits.prune(spec_groups)

            spec_only_req = ExecuteModelRequest(
                seq_group_metadata_list=spec_groups,
                num_lookahead_slots=tree_size,
                previous_hidden_states=prev_hidden,
                previous_logits=prev_logits,
            )
            draft_output, is_transposed = self._worker.sampler_output(
                execute_model_req=spec_only_req,
                sample_len=tree_size,
                seq_ids_with_bonus_token_in_last_step=(
                    seq_ids_with_bonus_token_in_last_step),
            )
            lens, draft_output, spec_indices = self._remove_no_proposal_seqs(
                lens, draft_output, spec_indices, is_transposed)

        # Bring speculated and skipped sequences into one representation.
        (proposal_tokens, proposal_probs, proposal_lens, cart_candidates,
         tree_attn_masks) = self._merge_outputs(
             batch_size=len(all_groups),
             proposal_len=tree_size,
             maybe_sampler_output=draft_output,
             proposal_lens=lens,
             nonzero_proposal_len_indices=spec_indices,
             sampler_transposed=is_transposed,
         )

        # Shift the tree position ids by each sequence's length; when the
        # sequence has exactly one output token the length is reduced by 1.
        shifted_position_ids = []
        for group in all_groups:
            first_seq = next(iter(group.seq_data.values()))
            offset = (first_seq.get_len() - 1
                      if first_seq.get_output_len() == 1 else
                      first_seq.get_len())
            shifted_position_ids.append(
                self.tree_buffers['tree_position_ids'] + offset)
        position_ids = torch.stack(shifted_position_ids,
                                   dim=0).reshape(-1, 1)

        return SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
            no_proposals=draft_output is None,
            cart_candidates=cart_candidates,
            retrieve_indices=self.tree_buffers['retrieve_indices'],
            tree_attn_masks=tree_attn_masks,
            tree_position_ids=position_ids)

    def _split_by_proposal_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Partition the batch into speculated and skipped groups.

        Returns the per-group proposal lengths, the groups with a non-zero
        proposal length, and the batch indices of those groups. Each
        group's ``num_speculative_tokens`` is overwritten with its granted
        length unless it was already 0.
        """
        lens: List[int] = []
        spec_groups: List[SequenceGroupMetadata] = []
        spec_indices: List[int] = []
        limit = self.max_proposal_len

        for idx, group in enumerate(seq_group_metadata_list):
            # Speculation was already disabled for this request.
            if group.num_speculative_tokens == 0:
                lens.append(0)
                continue

            first_seq = next(iter(group.seq_data.values()))
            seq_len = first_seq.get_len()

            # Grant either the full proposal length or nothing; a defined
            # max_proposal_len must not be reached.
            fits = limit is None or seq_len + proposal_len < limit
            if fits:
                spec_groups.append(group)
                spec_indices.append(idx)
            granted = proposal_len if fits else 0
            lens.append(granted)
            group.num_speculative_tokens = granted

        return lens, spec_groups, spec_indices

    @staticmethod
    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                 nonzero_proposal_len_indices, transposed):
        """Drop sequences for which the draft worker returned no proposal.

        Such sequences get a proposal length of 0 and are removed from the
        non-zero index list. The inputs are returned unchanged when
        ``maybe_sampler_output`` is None or ``transposed`` is true.
        """
        if maybe_sampler_output is None or transposed:
            return (proposal_lens, maybe_sampler_output,
                    nonzero_proposal_len_indices)

        kept_lens: List[int] = []
        kept_indices: List[int] = []
        kept_outputs: List[SamplerOutput] = []

        total = len(proposal_lens)
        num_spec = len(nonzero_proposal_len_indices)
        cursor = 0  # position in nonzero_proposal_len_indices
        pos = 0  # position in proposal_lens
        while pos < total and cursor < num_spec:
            if pos < nonzero_proposal_len_indices[cursor]:
                # Was never sent to the draft worker, so its length is 0.
                assert proposal_lens[pos] == 0
                kept_lens.append(0)
                pos += 1
                continue

            draft = maybe_sampler_output[cursor]
            cursor += 1
            if draft is None:
                # Sent to the draft worker but came back without proposal.
                kept_lens.append(0)
            else:
                kept_lens.append(proposal_lens[pos])
                kept_indices.append(pos)
                kept_outputs.append(draft)
            pos += 1

        # Everything after the last speculated sequence keeps its length.
        kept_lens.extend(proposal_lens[pos:])

        # A sampler output consisting only of Nones is not expected here.
        assert kept_outputs

        return kept_lens, kept_outputs, kept_indices

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Merge draft results with the sequences that were skipped.

        Returns proposal tokens, proposal probs, proposal lengths, cart
        candidates and tree attention masks. When a sampler output is
        present, the proposal probs entry is None and the masks are
        reshaped to two dimensions.
        """
        retrieve_indices = self.tree_buffers["retrieve_indices"]

        if maybe_sampler_output is None:
            # No speculative tokens at all: return placeholder proposals.
            def broadcast(value, dtype, *shape):
                scalar = torch.tensor(value,
                                      dtype=dtype,
                                      device=self._device)
                return scalar.expand(*shape)

            empty_tokens = broadcast(-1, torch.long, batch_size,
                                     proposal_len)
            empty_probs = broadcast(0, torch.float32, batch_size,
                                    proposal_len, self._vocab_size)
            empty_lens = broadcast(0, torch.long, len(proposal_lens))
            empty_candidates = broadcast(0, torch.long, batch_size,
                                         retrieve_indices.shape[0],
                                         retrieve_indices.shape[1])
            empty_masks = broadcast(
                0, torch.int32, batch_size,
                self.tree_buffers["tree_attn_masks"].shape[0],
                self.tree_buffers["tree_attn_masks"].shape[1])
            return (empty_tokens, empty_probs, empty_lens, empty_candidates,
                    empty_masks)

        (draft_tokens, _unused_probs, _, _, draft_candidates,
         draft_masks) = sampler_output_to_torch(maybe_sampler_output,
                                                sampler_transposed)

        # Give every sequence an entry; skipped ones keep the fill value.
        merged_tokens = draft_tokens.new_full(
            size=(batch_size, *draft_tokens.shape[1:]),
            fill_value=-1,
        )
        merged_tokens[nonzero_proposal_len_indices] = draft_tokens

        merged_lens = torch.zeros(batch_size,
                                  dtype=torch.long,
                                  device=self._device)
        merged_lens[nonzero_proposal_len_indices] = proposal_len

        merged_candidates = draft_candidates.new_zeros(
            batch_size, *draft_candidates.shape[1:])
        merged_candidates[nonzero_proposal_len_indices] = draft_candidates

        merged_masks = draft_masks.new_zeros(batch_size,
                                             *draft_masks.shape[1:])
        merged_masks[nonzero_proposal_len_indices] = draft_masks
        merged_masks = merged_masks.reshape(-1, draft_masks.shape[-1])

        # Proposal probabilities are not provided in this branch.
        merged_probs = None

        return (merged_tokens, merged_probs, merged_lens, merged_candidates,
                merged_masks)
vllm/spec_decode/util.py
View file @
19bc93d9
...
@@ -147,7 +147,8 @@ def split_batch_by_proposal_len(
...
@@ -147,7 +147,8 @@ def split_batch_by_proposal_len(
def
sampler_output_to_torch
(
def
sampler_output_to_torch
(
sampler_output_list
:
Sequence
[
SamplerOutput
],
sampler_transposed
:
bool
sampler_output_list
:
Sequence
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Utility function which converts a list of SamplerOutput to tensors.
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
sampler_transposed here is used as the indicator for whether
...
@@ -162,6 +163,8 @@ def sampler_output_to_torch(
...
@@ -162,6 +163,8 @@ def sampler_output_to_torch(
"""
"""
# shape: [batch_size, num_sampler_output, vocab_size]
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_probs
=
None
if
sampler_output_list
[
0
].
sampled_token_probs
is
not
None
:
sampled_token_probs
=
torch
.
stack
(
sampled_token_probs
=
torch
.
stack
(
[
[
sampler_output
.
sampled_token_probs
sampler_output
.
sampled_token_probs
...
@@ -171,6 +174,8 @@ def sampler_output_to_torch(
...
@@ -171,6 +174,8 @@ def sampler_output_to_torch(
)
)
# shape: [batch_size, num_sampler_output, vocab_size]
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
None
if
sampler_output_list
[
0
].
logprobs
is
not
None
:
sampled_token_logprobs
=
torch
.
stack
(
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
dim
=
0
,
...
@@ -205,8 +210,33 @@ def sampler_output_to_torch(
...
@@ -205,8 +210,33 @@ def sampler_output_to_torch(
else
:
else
:
sampled_hidden_states
=
None
sampled_hidden_states
=
None
sampled_cart_candidates
=
None
if
sampler_output_list
[
0
].
cart_candidates
is
not
None
:
sampled_cart_candidates
=
torch
.
cat
(
[
sampler_output
.
cart_candidates
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_cart_candidates
=
sampled_cart_candidates
.
transpose
(
0
,
1
)
sampled_tree_attn_masks
=
None
if
sampler_output_list
[
0
].
tree_attn_masks
is
not
None
:
sampled_tree_attn_masks
=
torch
.
stack
(
[
sampler_output
.
tree_attn_masks
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_tree_attn_masks
=
sampled_tree_attn_masks
.
transpose
(
0
,
1
)
return
(
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
,
return
(
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
,
sampled_hidden_states
)
sampled_hidden_states
,
sampled_cart_candidates
,
sampled_tree_attn_masks
)
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
...
...
vllm/transformers_utils/configs/medusa.py
View file @
19bc93d9
import
os
import
os
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
,
List
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -15,6 +15,7 @@ class MedusaConfig(PretrainedConfig):
...
@@ -15,6 +15,7 @@ class MedusaConfig(PretrainedConfig):
max_paths
:
int
=
64
,
max_paths
:
int
=
64
,
topk
:
int
=
10
,
topk
:
int
=
10
,
truncated_vocab_size
:
Optional
[
int
]
=
None
,
truncated_vocab_size
:
Optional
[
int
]
=
None
,
medusa_choices
:
List
[
List
[
int
]]
=
None
,
**
kwargs
):
**
kwargs
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -26,6 +27,7 @@ class MedusaConfig(PretrainedConfig):
...
@@ -26,6 +27,7 @@ class MedusaConfig(PretrainedConfig):
self
.
max_seq_len
=
int
(
2
**
20
)
self
.
max_seq_len
=
int
(
2
**
20
)
self
.
truncated_vocab_size
=
vocab_size
if
truncated_vocab_size
is
None
\
self
.
truncated_vocab_size
=
vocab_size
if
truncated_vocab_size
is
None
\
else
truncated_vocab_size
else
truncated_vocab_size
self
.
medusa_choices
=
medusa_choices
if
"architectures"
not
in
kwargs
:
if
"architectures"
not
in
kwargs
:
kwargs
[
"architectures"
]
=
[
"MedusaModel"
]
kwargs
[
"architectures"
]
=
[
"MedusaModel"
]
...
@@ -51,6 +53,14 @@ class MedusaConfig(PretrainedConfig):
...
@@ -51,6 +53,14 @@ class MedusaConfig(PretrainedConfig):
def
num_attention_heads
(
self
):
def
num_attention_heads
(
self
):
return
0
return
0
@
property
def
num_lookahead_heads
(
self
):
return
self
.
num_heads
@
num_lookahead_heads
.
setter
def
num_lookahead_heads
(
self
,
num_lookahead_heads
:
int
):
self
.
num_heads
=
num_lookahead_heads
@
property
@
property
def
num_lookahead_tokens
(
self
):
def
num_lookahead_tokens
(
self
):
return
self
.
num_heads
return
self
.
num_heads
...
...
vllm/worker/cpu_worker.py
View file @
19bc93d9
...
@@ -13,6 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -13,6 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
...
@@ -309,6 +310,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -309,6 +310,10 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
return
self
.
cpu_cache
return
self
.
cpu_cache
@
property
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
return
None
def
execute_worker
(
def
execute_worker
(
self
,
self
,
worker_input
:
WorkerInput
,
worker_input
:
WorkerInput
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment