Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44e3ca68
"docs/serving/parallelism_scaling.md" did not exist on "c3649e4feeed30594f2de8f5183bd24b50b80f1c"
Commit
44e3ca68
authored
Dec 11, 2024
by
王敏
Browse files
[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑主干隔离
parent
54b92ba4
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
330 additions
and
282 deletions
+330
-282
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+0
-10
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+0
-44
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+0
-10
vllm/attention/backends/tree_decoding_utils.py
vllm/attention/backends/tree_decoding_utils.py
+55
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+3
-3
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+0
-44
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+208
-61
vllm/config.py
vllm/config.py
+2
-10
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+12
-10
vllm/envs.py
vllm/envs.py
+7
-0
vllm/lora/models.py
vllm/lora/models.py
+0
-2
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+6
-2
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+4
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+13
-18
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+7
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-35
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+12
-20
No files found.
vllm/attention/backends/pallas.py
View file @
44e3ca68
...
...
@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
class
PallasMetadata
(
AttentionMetadata
):
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
44e3ca68
...
...
@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
@
dataclass
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/torch_sdpa.py
View file @
44e3ca68
...
...
@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/tree_decoding_utils.py
0 → 100644
View file @
44e3ca68
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
,
Optional
import
torch
from
vllm.attention.backends.blocksparse_attn
import
BlocksparseFlashAttentionImpl
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.paged_attn
import
PagedAttention
def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move per-token KV-cache entries from source to destination positions.

    For every layer, the entries addressed by column 0 of ``src_to_dists``
    are first gathered into a temporary buffer (``ops.read_cache``) and then
    scattered to the positions addressed by column 1
    (``ops.write_cache_multi_layers``). Reading everything before writing
    means overlapping source/destination positions cannot clobber each other
    within one call.

    Args:
        backend: Attention backend object/class exposing ``get_name()``.
        kv_caches: One KV-cache tensor per layer. Each is split into key and
            value caches with ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor with one row per token to move; column 0 is
            the source index and column 1 the destination index.
            (Presumably flat cache-slot indices -- confirm against the
            ``read_cache`` / ``write_cache_multi_layers`` kernels.)
        kv_cache_dtype: KV-cache dtype string forwarded to the custom ops.
        num_kv_heads: Number of KV heads, used to size the staging buffer and
            to split each layer's cache.
        head_size: Per-head size, used the same way as ``num_kv_heads``.

    Raises:
        NotImplementedError: If ``backend.get_name()`` is neither
            ``"rocm-flash-attn"`` nor ``"xformers"``.
    """
    backend_name = backend.get_name()
    # NOTE(review): the previous message also listed BlocksparseFlashAttention,
    # but no blocksparse backend name was ever accepted by this check, so the
    # message now reflects what is actually supported.
    if backend_name not in ("rocm-flash-attn", "xformers"):
        raise NotImplementedError(
            "Only ROCmFlash/XFormers backends support move cache for now! "
            f"(got backend: {backend_name})")

    token_num = src_to_dists.shape[0]
    if token_num == 0:
        # Nothing to move; skip the allocation and the kernel launches.
        return

    num_layers = len(kv_caches)

    # Single staging allocation: index 0 holds keys, index 1 holds values.
    tmp_store_kv = torch.empty(
        (2, num_layers, token_num, num_kv_heads, head_size),
        dtype=kv_caches[0].dtype,
        device=kv_caches[0].device)
    keys = tmp_store_kv[0].contiguous()
    values = tmp_store_kv[1].contiguous()

    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_size)
        key_caches.append(key_cache)
        value_caches.append(value_cache)

    # Gather from the source positions of every layer ...
    ops.read_cache(keys, values, key_caches, value_caches,
                   src_to_dists[:, 0].contiguous(), kv_cache_dtype)
    # ... then scatter to the destination positions of every layer.
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 src_to_dists[:, 1].contiguous(),
                                 kv_cache_dtype)
\ No newline at end of file
vllm/attention/backends/utils.py
View file @
44e3ca68
...
...
@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState
)
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
...
@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
...
...
@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
tree_attention_masks_tensor
=
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables
)
...
...
vllm/attention/backends/xformers.py
View file @
44e3ca68
...
...
@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/ops/paged_attn.py
View file @
44e3ca68
...
...
@@ -142,7 +142,102 @@ class PagedAttention:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
ops
.
paged_attention_v1_opt_tc
(
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_tc
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_opt_tc_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_opt_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
...
...
@@ -161,12 +256,10 @@ class PagedAttention:
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_
opt
(
ops
.
paged_attention_v1_
with_mask
(
output
,
query
,
key_cache
,
...
...
@@ -189,30 +282,6 @@ class PagedAttention:
attn_masks
,
attn_masks_stride
)
else
:
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
...
...
@@ -236,7 +305,114 @@ class PagedAttention:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
ops
.
paged_attention_v2_opt_tc
(
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt_tc
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v2_opt_tc_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v2_opt_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
...
...
@@ -258,12 +434,10 @@ class PagedAttention:
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v2_
opt
(
ops
.
paged_attention_v2_
with_mask
(
output
,
exp_sums
,
max_logits
,
...
...
@@ -289,33 +463,6 @@ class PagedAttention:
attn_masks
,
attn_masks_stride
)
else
:
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
return
output
@
staticmethod
...
...
vllm/config.py
View file @
44e3ca68
...
...
@@ -1130,7 +1130,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
disable_logprobs
:
Optional
[
bool
],
num_speculative_heads
:
Optional
[
int
],
tree_style_spec_decoding
:
Optional
[
bool
]
=
None
,
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
...
...
@@ -1191,9 +1190,6 @@ class SpeculativeConfig:
num_speculative_heads (Optional[int]): It will be used in tree-style
speculative generation, representing how many heads the draft model
has.
tree_style_spec_decoding (Optional[bool]): If set to True,
tree-style generation will be activated. If not specified,
it defaults to False.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
...
@@ -1308,9 +1304,9 @@ class SpeculativeConfig:
"n_predict parameter."
)
if
typical_acceptance_sampler_posterior_threshold
is
None
:
typical_acceptance_sampler_posterior_threshold
=
0.
3
typical_acceptance_sampler_posterior_threshold
=
0.
09
if
typical_acceptance_sampler_posterior_alpha
is
None
:
typical_acceptance_sampler_posterior_alpha
=
0.
09
typical_acceptance_sampler_posterior_alpha
=
0.
3
if
disable_logprobs
is
None
:
disable_logprobs
=
True
...
...
@@ -1328,7 +1324,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
tree_style_spec_decoding
=
tree_style_spec_decoding
)
@
staticmethod
...
...
@@ -1423,7 +1418,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
tree_style_spec_decoding
:
bool
,
):
"""Create a SpeculativeConfig object.
...
...
@@ -1458,7 +1452,6 @@ class SpeculativeConfig:
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
tree_style_spec_decoding: Whether to use tree-style generation.
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
...
...
@@ -1474,7 +1467,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_log_stats
=
disable_log_stats
self
.
tree_style_spec_decoding
=
tree_style_spec_decoding
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
44e3ca68
...
...
@@ -176,7 +176,6 @@ class EngineArgs:
disable_logprobs_during_spec_decoding
:
Optional
[
bool
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
tree_style_spec_decoding
:
Optional
[
bool
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
...
...
@@ -707,11 +706,6 @@ class EngineArgs:
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.'
)
parser
.
add_argument
(
'--tree-style-spec-decoding'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, tree-style generation will be activated.'
)
parser
.
add_argument
(
'--typical-acceptance-sampler-posterior-threshold'
,
...
...
@@ -997,7 +991,6 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha
=
self
.
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
tree_style_spec_decoding
=
self
.
tree_style_spec_decoding
,
num_speculative_heads
=
self
.
num_speculative_heads
)
...
...
vllm/engine/llm_engine.py
View file @
44e3ca68
import
os
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
...
...
@@ -463,6 +464,8 @@ class LLMEngine:
get_tokenizer_for_seq
,
),
))
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -989,16 +992,15 @@ class LLMEngine:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
# tree style speculative decoding may generate empty output in first step
if
self
.
speculative_config
and
self
.
speculative_config
.
tree_style_spec_decoding
:
if
outputs
and
isinstance
(
output
[
0
],
CompletionSequenceGroupOutput
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
if
len
(
valid_samples
)
==
0
:
empty_seq_indices
.
append
(
i
)
continue
if
self
.
tree_decoding
and
outputs
and
isinstance
(
output
[
0
],
CompletionSequenceGroupOutput
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
if
len
(
valid_samples
)
==
0
:
empty_seq_indices
.
append
(
i
)
continue
if
not
is_async
:
seq_group
.
update_num_computed_tokens
(
...
...
vllm/envs.py
View file @
44e3ca68
...
...
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH
:
bool
=
False
VLLM_TREE_DECODING
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -453,6 +454,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
# If set, vLLM will use tree-style speculative decoding.
"VLLM_TREE_DECODING"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_TREE_DECODING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
))
}
# end-env-vars-definition
...
...
vllm/lora/models.py
View file @
44e3ca68
...
...
@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel):
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
for
tensor_name
,
tensor
in
tensors
.
items
():
if
"lora_A"
not
in
tensor_name
and
"lora_B"
not
in
tensor_name
:
continue
module_name
,
is_lora_a
=
parse_fine_tuned_lora_name
(
tensor_name
)
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
44e3ca68
import
os
from
typing
import
Optional
,
List
import
torch
import
torch.jit
...
...
@@ -39,6 +40,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
.
_posterior_alpha
=
posterior_alpha
super
().
__init__
(
strict_mode
=
strict_mode
)
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
forward
(
self
,
target_with_bonus_probs
:
torch
.
Tensor
,
...
...
@@ -92,7 +95,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
draft_token_ids
,
bonus_token_ids
)
if
cart_candidates
is
None
:
if
not
self
.
tree_decoding
:
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
...
...
@@ -101,6 +104,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
draft_token_ids
,
bonus_token_ids
)
else
:
assert
cart_candidates
is
not
None
target_probs
=
target_with_bonus_probs
output_token_ids
=
self
.
_evaluate_accepted_tokens_tree_style
(
target_probs
,
draft_token_ids
,
...
...
@@ -199,7 +203,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
output_token_id_list
=
[]
accept_length_list
=
accept_length
.
cpu
().
tolist
()
logger
.
info
(
"accept_length:%s"
,
accept_length_list
)
#
logger.info("accept_length:%s", accept_length_list)
for
i
in
range
(
batch_size
):
output_best_candidates
.
append
(
best_candidate
[
i
])
accept_lengths
.
append
(
accept_length_list
[
i
])
...
...
vllm/spec_decode/medusa_worker.py
View file @
44e3ca68
import
os
import
weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Dict
...
...
@@ -29,6 +30,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Lazy initialization list.
self
.
_proposer
:
SpeculativeProposer
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
init_device
(
self
):
super
().
init_device
()
...
...
@@ -38,7 +41,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# get medusa choices and generate medusa_buffers
self
.
medusa_buffers
=
None
if
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
if
self
.
tree_decoding
and
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
if
self
.
medusa_choices
is
not
None
:
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
44e3ca68
import
os
from
collections
import
defaultdict
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
...
...
@@ -82,9 +83,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
tree_style_spec_decoding
=
speculative_config
.
tree_style_spec_decoding
,
)
disable_log_stats
=
speculative_config
.
disable_log_stats
)
return
spec_decode_worker
...
...
@@ -126,7 +125,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
tree_style_spec_decoding
:
bool
,
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
...
...
@@ -191,8 +189,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
tree_style_spec_decoding
=
tree_style_spec_decoding
)
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
def
__init__
(
self
,
...
...
@@ -204,7 +201,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
tree_style_spec_decoding
:
bool
=
False
,
):
"""
Create a SpecDecodeWorker.
...
...
@@ -233,7 +229,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
tree_style_spec_decoding: Whether to use tree-style generation.
"""
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
...
...
@@ -268,7 +263,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
self
.
tree_
style_spec_decoding
=
tree_style_spec_decoding
self
.
tree_
decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
...
...
@@ -285,7 +280,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
if
not
self
.
tree_
style_spec_
decoding
:
if
not
self
.
tree_decoding
:
self
.
scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
...
...
@@ -324,7 +319,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
=
True
# tree-style decoding modifies probs in _verify_tokens
if
not
self
.
tree_
style_spec_
decoding
:
if
not
self
.
tree_decoding
:
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
...
...
@@ -535,7 +530,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
if
self
.
tree_
style_spec_
decoding
and
self
.
kvcache_slot_to_be_moved
is
not
None
:
if
self
.
tree_decoding
and
self
.
kvcache_slot_to_be_moved
is
not
None
:
execute_model_req
.
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
self
.
kvcache_slot_to_be_moved
=
None
...
...
@@ -560,7 +555,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
# Store logits from target model execution.
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
logits
=
sampler_output
.
logits
if
logits
is
not
None
:
if
self
.
previous_logits
is
None
:
...
...
@@ -612,7 +607,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
scorer_worker
.
execute_model
()
if
not
data
[
"disable_all_speculation"
]:
# if not self.tree_
style_spec_
decoding:
# if not self.tree_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# #
...
...
@@ -677,7 +672,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"workers generate no tokens"
)
# Pass tree attention masks and positions to the target model
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
execute_model_req
.
tree_attn_masks
=
proposals
.
tree_attn_masks
execute_model_req
.
tree_position_ids
=
proposals
.
tree_position_ids
...
...
@@ -695,7 +690,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
execute_model_req
.
num_lookahead_slots
)
# move kv_caches of selected tokens to right positions
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
self
.
move_caches
(
execute_model_req
,
select_indices_list
,
accept_lengths
)
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
...
...
@@ -739,7 +734,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
else
:
proposal_verifier_probs
=
proposal_scores
.
probs
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
proposal_verifier_probs
=
proposal_verifier_probs
[:,
retrieve_indices
]
...
...
@@ -797,7 +792,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
if
not
self
.
tree_
style_spec_
decoding
:
if
not
self
.
tree_decoding
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
1
).
clone
()
else
:
...
...
vllm/worker/cache_engine.py
View file @
44e3ca68
...
...
@@ -8,6 +8,7 @@ from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
is_pin_memory_available
)
from
vllm.attention.backends.tree_decoding_utils
import
move_cache
logger
=
init_logger
(
__name__
)
...
...
@@ -103,11 +104,12 @@ class CacheEngine:
def
move_caches
(
self
,
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
move_cache
(
kv_caches
,
src_to_dsts
,
self
.
cache_config
.
cache_dtype
,
self
.
num_kv_heads
,
self
.
head_size
)
move_cache
(
self
.
attn_backend
,
kv_caches
,
src_to_dsts
,
self
.
cache_config
.
cache_dtype
,
self
.
num_kv_heads
,
self
.
head_size
)
@
staticmethod
def
get_cache_block_size
(
...
...
vllm/worker/model_runner.py
View file @
44e3ca68
...
...
@@ -198,7 +198,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
lora_requests
.
clear
()
# type: ignore
self
.
prompt_adapter_index_mapping
.
clear
()
# type: ignore
self
.
prompt_adapter_prompt_mapping
.
clear
()
# type: ignore
self
.
tree_attn_masks
[
0
]
=
None
# type: ignore
def
__init__
(
self
,
...
...
@@ -246,9 +245,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
reinit
:
bool
=
False
,
reinit_use_defaults
:
bool
=
False
,
encoder_seq_len
:
int
=
0
,
# attention mask used in tree-style generation
tree_attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
):
if
reinit
:
assert
len
(
self
.
seq_ids
)
==
len
(
seq_ids
)
# type: ignore
...
...
@@ -339,12 +335,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping
else
:
self
.
prompt_adapter_prompt_mapping
.
clear
()
if
tree_attn_masks
:
self
.
tree_attn_masks
=
tree_attn_masks
else
:
self
.
tree_attn_masks
.
clear
()
else
:
self
.
input_tokens
=
input_tokens
or
[]
self
.
input_positions
=
input_positions
or
[]
...
...
@@ -364,7 +354,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_index_mapping
or
[])
self
.
prompt_adapter_prompt_mapping
=
(
prompt_adapter_prompt_mapping
or
[])
self
.
tree_attn_masks
=
tree_attn_masks
or
[]
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
multi_modal_inputs
=
multi_modal_inputs
...
...
@@ -380,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
input_tokens
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
input_positions
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
tree_attn_masks
=
[
None
for
_
in
range
(
self
.
n_seqs
)]
self
.
mrope_input_positions
=
None
self
.
seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
orig_seq_lens
=
[
0
]
*
self
.
n_seqs
...
...
@@ -469,13 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
if
hasattr
(
self
.
runner
,
"tree_attn_masks"
):
self
.
tree_attn_masks
=
self
.
runner
.
tree_attn_masks
self
.
tree_position_ids
=
self
.
runner
.
tree_position_ids
else
:
self
.
tree_attn_masks
=
None
self
.
tree_position_ids
=
None
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder_model
...
...
@@ -853,16 +834,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
cuda_graph_pad_size
:
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
# prepare tree attention masks
tree_attention_masks_tensor
=
self
.
tree_attn_masks
if
tree_attention_masks_tensor
is
not
None
:
tree_attention_masks_tensor
=
tree_attention_masks_tensor
.
contiguous
()
input_positions_tensor
=
self
.
tree_position_ids
.
contiguous
()
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
,
tree_attention_masks_tensor
=
tree_attention_masks_tensor
)
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
# LoRA data.
lora_requests
=
set
()
...
...
@@ -1033,9 +1007,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
inter_data_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
SamplingMetadataCache
()
self
.
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
self
.
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
...
...
@@ -1503,11 +1474,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
def
set_tree_style_args
(
self
,
tree_attn_masks
:
Optional
[
torch
.
Tensor
],
tree_position_ids
:
Optional
[
torch
.
Tensor
]):
self
.
tree_attn_masks
=
tree_attn_masks
self
.
tree_position_ids
=
tree_position_ids
class
ModelRunner
(
GPUModelRunnerBase
[
ModelInputForGPUWithSamplingMetadata
]):
...
...
vllm/worker/worker_base.py
View file @
44e3ca68
...
...
@@ -31,6 +31,7 @@ class WorkerBase(ABC):
"""
model_input
:
Optional
[
ModelRunnerInputBase
]
=
None
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
@
abstractmethod
def
init_device
(
self
)
->
None
:
...
...
@@ -103,18 +104,6 @@ class WorkerBase(ABC):
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
@
property
@
abstractmethod
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise
NotImplementedError
@
property
@
abstractmethod
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
...
...
@@ -138,10 +127,6 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def
list_loras
(
self
)
->
Set
[
int
]:
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
return
None
@
property
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
...
...
@@ -282,10 +267,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
if
hasattr
(
self
.
model_runner
,
"set_tree_style_args"
):
self
.
model_runner
.
set_tree_style_args
(
tree_attn_masks
=
execute_model_req
.
tree_attn_masks
,
tree_position_ids
=
execute_model_req
.
tree_position_ids
)
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
...
...
@@ -293,6 +274,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
execute_model_req
.
tree_attn_masks
is
not
None
:
if
hasattr
(
model_input
,
"input_positions"
)
and
\
hasattr
(
model_input
,
"attn_metadata"
)
and
\
hasattr
(
model_input
.
attn_metadata
,
"tree_attention_masks_tensor"
):
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
.
tree_attention_masks_tensor
=
execute_model_req
.
tree_attn_masks
.
contiguous
()
model_input
=
dataclasses
.
replace
(
model_input
,
input_positions
=
execute_model_req
.
tree_position_ids
.
contiguous
(),
attn_metadata
=
attn_metadata
)
kwargs
=
extract_previous_hidden_states
(
execute_model_req
)
if
self
.
do_metadata_broadcast
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment