Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44e3ca68
Commit
44e3ca68
authored
Dec 11, 2024
by
王敏
Browse files
[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑主干隔离
parent
54b92ba4
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
330 additions
and
282 deletions
+330
-282
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+0
-10
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+0
-44
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+0
-10
vllm/attention/backends/tree_decoding_utils.py
vllm/attention/backends/tree_decoding_utils.py
+55
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+3
-3
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+0
-44
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+208
-61
vllm/config.py
vllm/config.py
+2
-10
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-7
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+12
-10
vllm/envs.py
vllm/envs.py
+7
-0
vllm/lora/models.py
vllm/lora/models.py
+0
-2
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+6
-2
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+4
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+13
-18
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+7
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-35
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+12
-20
No files found.
vllm/attention/backends/pallas.py
View file @
44e3ca68
...
@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
...
@@ -53,16 +53,6 @@ class PallasAttentionBackend(AttentionBackend):
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
v_cache
,
True
)
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
v_cache
[:,
dst_indices
]
=
v_cache
[:,
src_indices
]
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
@
dataclass
class
PallasMetadata
(
AttentionMetadata
):
class
PallasMetadata
(
AttentionMetadata
):
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
44e3ca68
...
@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -72,50 +72,6 @@ class ROCmFlashAttentionBackend(AttentionBackend):
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
@
dataclass
@
dataclass
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/torch_sdpa.py
View file @
44e3ca68
...
@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -65,16 +65,6 @@ class TorchSDPABackend(AttentionBackend):
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
NotImplementedError
@
dataclass
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/backends/tree_decoding_utils.py
0 → 100644
View file @
44e3ca68
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
,
Optional
import
torch
from
vllm.attention.backends.blocksparse_attn
import
BlocksparseFlashAttentionImpl
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.paged_attn
import
PagedAttention
def
move_cache
(
backend
,
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
if
backend
.
get_name
()
==
"rocm-flash-attn"
or
\
backend
.
get_name
()
==
"xformers"
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
else
:
raise
NotImplementedError
(
"Only BlocksparseFlashAttention/ROCmFlash/XFormers backends support move cache for now!"
)
\ No newline at end of file
vllm/attention/backends/utils.py
View file @
44e3ca68
...
@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
...
@@ -9,6 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState
)
AttentionState
)
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -188,8 +190,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
block_size
,
inter_data
.
block_tables
)
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
tree_attention_masks_tensor
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Build attention metadata with on-device tensors.
"""Build attention metadata with on-device tensors.
Args:
Args:
...
@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -272,7 +273,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
tree_attention_masks_tensor
=
tree_attention_masks_tensor
,
block_tables_list
=
self
.
block_tables
block_tables_list
=
self
.
block_tables
)
)
...
...
vllm/attention/backends/xformers.py
View file @
44e3ca68
...
@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
...
@@ -68,50 +68,6 @@ class XFormersBackend(AttentionBackend):
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
move_cache
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
None
:
key_caches
=
[]
value_caches
=
[]
num_layers
=
len
(
kv_caches
)
token_num
=
src_to_dists
.
shape
[
0
]
tmp_store_kv
=
torch
.
empty
(
(
2
,
num_layers
,
token_num
,
num_kv_heads
,
head_size
),
dtype
=
kv_caches
[
0
].
dtype
,
device
=
kv_caches
[
0
].
device
)
keys
=
tmp_store_kv
[
0
].
contiguous
()
values
=
tmp_store_kv
[
1
].
contiguous
()
for
kv_cache
in
kv_caches
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
num_kv_heads
,
head_size
)
key_caches
.
append
(
key_cache
)
value_caches
.
append
(
value_cache
)
ops
.
read_cache
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
0
].
contiguous
(),
kv_cache_dtype
)
ops
.
write_cache_multi_layers
(
keys
,
values
,
key_caches
,
value_caches
,
src_to_dists
[:,
1
].
contiguous
(),
kv_cache_dtype
)
@
dataclass
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
...
...
vllm/attention/ops/paged_attn.py
View file @
44e3ca68
...
@@ -142,7 +142,102 @@ class PagedAttention:
...
@@ -142,7 +142,102 @@ class PagedAttention:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
ops
.
paged_attention_v1_opt_tc
(
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt_tc
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_opt_tc_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1_opt
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v1_opt_with_mask
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v1
(
output
,
output
,
query
,
query
,
key_cache
,
key_cache
,
...
@@ -161,12 +256,10 @@ class PagedAttention:
...
@@ -161,12 +256,10 @@ class PagedAttention:
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
ops
.
paged_attention_v1_
opt
(
ops
.
paged_attention_v1_
with_mask
(
output
,
output
,
query
,
query
,
key_cache
,
key_cache
,
...
@@ -189,30 +282,6 @@ class PagedAttention:
...
@@ -189,30 +282,6 @@ class PagedAttention:
attn_masks
,
attn_masks
,
attn_masks_stride
attn_masks_stride
)
)
else
:
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
assert
_PARTITION_SIZE
%
block_size
==
0
...
@@ -236,7 +305,114 @@ class PagedAttention:
...
@@ -236,7 +305,114 @@ class PagedAttention:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_OPT_OP
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
if
envs
.
VLLM_USE_TC_PAGED_ATTN
:
ops
.
paged_attention_v2_opt_tc
(
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt_tc
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v2_opt_tc_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2_opt
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
else
:
ops
.
paged_attention_v2_opt_with_mask
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
else
:
if
attn_masks
is
None
:
ops
.
paged_attention_v2
(
output
,
output
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
...
@@ -258,12 +434,10 @@ class PagedAttention:
...
@@ -258,12 +434,10 @@ class PagedAttention:
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
blocksparse_head_sliding_step
attn_masks
,
attn_masks_stride
)
)
else
:
else
:
ops
.
paged_attention_v2_
opt
(
ops
.
paged_attention_v2_
with_mask
(
output
,
output
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
...
@@ -289,33 +463,6 @@ class PagedAttention:
...
@@ -289,33 +463,6 @@ class PagedAttention:
attn_masks
,
attn_masks
,
attn_masks_stride
attn_masks_stride
)
)
else
:
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
attn_masks
,
attn_masks_stride
)
return
output
return
output
@
staticmethod
@
staticmethod
...
...
vllm/config.py
View file @
44e3ca68
...
@@ -1130,7 +1130,6 @@ class SpeculativeConfig:
...
@@ -1130,7 +1130,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
disable_logprobs
:
Optional
[
bool
],
disable_logprobs
:
Optional
[
bool
],
num_speculative_heads
:
Optional
[
int
],
num_speculative_heads
:
Optional
[
int
],
tree_style_spec_decoding
:
Optional
[
bool
]
=
None
,
)
->
Optional
[
"SpeculativeConfig"
]:
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
"""Create a SpeculativeConfig if possible, else return None.
...
@@ -1191,9 +1190,6 @@ class SpeculativeConfig:
...
@@ -1191,9 +1190,6 @@ class SpeculativeConfig:
num_speculative_heads (Optional[int]): It will be used in tree-style
num_speculative_heads (Optional[int]): It will be used in tree-style
speculative generation, representing how many heads the draft model
speculative generation, representing how many heads the draft model
has.
has.
tree_style_spec_decoding (Optional[bool]): If set to True,
tree-style generation will be activated. If not specified,
it defaults to False.
Returns:
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
@@ -1308,9 +1304,9 @@ class SpeculativeConfig:
...
@@ -1308,9 +1304,9 @@ class SpeculativeConfig:
"n_predict parameter."
)
"n_predict parameter."
)
if
typical_acceptance_sampler_posterior_threshold
is
None
:
if
typical_acceptance_sampler_posterior_threshold
is
None
:
typical_acceptance_sampler_posterior_threshold
=
0.
3
typical_acceptance_sampler_posterior_threshold
=
0.
09
if
typical_acceptance_sampler_posterior_alpha
is
None
:
if
typical_acceptance_sampler_posterior_alpha
is
None
:
typical_acceptance_sampler_posterior_alpha
=
0.
09
typical_acceptance_sampler_posterior_alpha
=
0.
3
if
disable_logprobs
is
None
:
if
disable_logprobs
is
None
:
disable_logprobs
=
True
disable_logprobs
=
True
...
@@ -1328,7 +1324,6 @@ class SpeculativeConfig:
...
@@ -1328,7 +1324,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
disable_logprobs
,
disable_logprobs
=
disable_logprobs
,
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
tree_style_spec_decoding
=
tree_style_spec_decoding
)
)
@
staticmethod
@
staticmethod
...
@@ -1423,7 +1418,6 @@ class SpeculativeConfig:
...
@@ -1423,7 +1418,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
disable_log_stats
:
bool
,
tree_style_spec_decoding
:
bool
,
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -1458,7 +1452,6 @@ class SpeculativeConfig:
...
@@ -1458,7 +1452,6 @@ class SpeculativeConfig:
returned.
returned.
disable_log_stats: Whether to disable periodic printing of stage
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
times in speculative decoding.
tree_style_spec_decoding: Whether to use tree-style generation.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
...
@@ -1474,7 +1467,6 @@ class SpeculativeConfig:
...
@@ -1474,7 +1467,6 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha
typical_acceptance_sampler_posterior_alpha
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_logprobs
=
disable_logprobs
self
.
disable_log_stats
=
disable_log_stats
self
.
disable_log_stats
=
disable_log_stats
self
.
tree_style_spec_decoding
=
tree_style_spec_decoding
self
.
_verify_args
()
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
44e3ca68
...
@@ -176,7 +176,6 @@ class EngineArgs:
...
@@ -176,7 +176,6 @@ class EngineArgs:
disable_logprobs_during_spec_decoding
:
Optional
[
bool
]
=
None
disable_logprobs_during_spec_decoding
:
Optional
[
bool
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
tree_style_spec_decoding
:
Optional
[
bool
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
...
@@ -707,11 +706,6 @@ class EngineArgs:
...
@@ -707,11 +706,6 @@ class EngineArgs:
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.'
)
'and vice versa.'
)
parser
.
add_argument
(
'--tree-style-spec-decoding'
,
type
=
bool
,
default
=
False
,
help
=
'If set to True, tree-style generation will be activated.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--typical-acceptance-sampler-posterior-threshold'
,
'--typical-acceptance-sampler-posterior-threshold'
,
...
@@ -997,7 +991,6 @@ class EngineArgs:
...
@@ -997,7 +991,6 @@ class EngineArgs:
typical_acceptance_sampler_posterior_alpha
=
self
.
typical_acceptance_sampler_posterior_alpha
=
self
.
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
tree_style_spec_decoding
=
self
.
tree_style_spec_decoding
,
num_speculative_heads
=
self
.
num_speculative_heads
num_speculative_heads
=
self
.
num_speculative_heads
)
)
...
...
vllm/engine/llm_engine.py
View file @
44e3ca68
import
os
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -463,6 +464,8 @@ class LLMEngine:
...
@@ -463,6 +464,8 @@ class LLMEngine:
get_tokenizer_for_seq
,
get_tokenizer_for_seq
,
),
),
))
))
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -989,16 +992,15 @@ class LLMEngine:
...
@@ -989,16 +992,15 @@ class LLMEngine:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
# tree style speculative decoding may generate empty output in first step
# tree style speculative decoding may generate empty output in first step
if
self
.
speculative_config
and
self
.
speculative_config
.
tree_style_spec_decoding
:
if
self
.
tree_decoding
and
outputs
and
isinstance
(
output
[
0
],
CompletionSequenceGroupOutput
):
if
outputs
and
isinstance
(
output
[
0
],
CompletionSequenceGroupOutput
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
valid_samples
=
[
sample
for
sample
in
samples
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
]
if
len
(
valid_samples
)
==
0
:
if
len
(
valid_samples
)
==
0
:
empty_seq_indices
.
append
(
i
)
empty_seq_indices
.
append
(
i
)
continue
continue
if
not
is_async
:
if
not
is_async
:
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
...
...
vllm/envs.py
View file @
44e3ca68
...
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
...
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH
:
bool
=
False
VLLM_TREE_DECODING
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -453,6 +454,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -453,6 +454,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
,
"0"
).
strip
().
lower
()
in
(
os
.
environ
.
get
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
(
"1"
,
"true"
)),
# If set, vLLM will use tree-style speculative decoding.
"VLLM_TREE_DECODING"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_TREE_DECODING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
))
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/lora/models.py
View file @
44e3ca68
...
@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel):
...
@@ -117,8 +117,6 @@ class LoRAModel(AdapterModel):
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
for
tensor_name
,
tensor
in
tensors
.
items
():
for
tensor_name
,
tensor
in
tensors
.
items
():
if
"lora_A"
not
in
tensor_name
and
"lora_B"
not
in
tensor_name
:
continue
module_name
,
is_lora_a
=
parse_fine_tuned_lora_name
(
tensor_name
)
module_name
,
is_lora_a
=
parse_fine_tuned_lora_name
(
tensor_name
)
if
module_name
not
in
loras
:
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
lora_embeddings_tensor
=
None
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
44e3ca68
import
os
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
import
torch
import
torch
import
torch.jit
import
torch.jit
...
@@ -39,6 +40,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -39,6 +40,8 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
.
_posterior_alpha
=
posterior_alpha
self
.
_posterior_alpha
=
posterior_alpha
super
().
__init__
(
strict_mode
=
strict_mode
)
super
().
__init__
(
strict_mode
=
strict_mode
)
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
forward
(
def
forward
(
self
,
self
,
target_with_bonus_probs
:
torch
.
Tensor
,
target_with_bonus_probs
:
torch
.
Tensor
,
...
@@ -92,7 +95,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -92,7 +95,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
draft_token_ids
,
bonus_token_ids
)
draft_token_ids
,
bonus_token_ids
)
if
cart_candidates
is
None
:
if
not
self
.
tree_decoding
:
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
draft_token_ids
)
...
@@ -101,6 +104,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -101,6 +104,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
)
bonus_token_ids
)
else
:
else
:
assert
cart_candidates
is
not
None
target_probs
=
target_with_bonus_probs
target_probs
=
target_with_bonus_probs
output_token_ids
=
self
.
_evaluate_accepted_tokens_tree_style
(
target_probs
,
output_token_ids
=
self
.
_evaluate_accepted_tokens_tree_style
(
target_probs
,
draft_token_ids
,
draft_token_ids
,
...
@@ -199,7 +203,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -199,7 +203,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
output_token_id_list
=
[]
output_token_id_list
=
[]
accept_length_list
=
accept_length
.
cpu
().
tolist
()
accept_length_list
=
accept_length
.
cpu
().
tolist
()
logger
.
info
(
"accept_length:%s"
,
accept_length_list
)
#
logger.info("accept_length:%s", accept_length_list)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
output_best_candidates
.
append
(
best_candidate
[
i
])
output_best_candidates
.
append
(
best_candidate
[
i
])
accept_lengths
.
append
(
accept_length_list
[
i
])
accept_lengths
.
append
(
accept_length_list
[
i
])
...
...
vllm/spec_decode/medusa_worker.py
View file @
44e3ca68
import
os
import
weakref
import
weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Dict
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Dict
...
@@ -29,6 +30,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -29,6 +30,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Lazy initialization list.
# Lazy initialization list.
self
.
_proposer
:
SpeculativeProposer
self
.
_proposer
:
SpeculativeProposer
self
.
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
init_device
(
self
):
def
init_device
(
self
):
super
().
init_device
()
super
().
init_device
()
...
@@ -38,7 +41,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -38,7 +41,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# get medusa choices and generate medusa_buffers
# get medusa choices and generate medusa_buffers
self
.
medusa_buffers
=
None
self
.
medusa_buffers
=
None
if
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
if
self
.
tree_decoding
and
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
if
self
.
medusa_choices
is
not
None
:
if
self
.
medusa_choices
is
not
None
:
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
44e3ca68
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
...
@@ -82,9 +83,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -82,9 +83,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
=
speculative_config
.
typical_acceptance_sampler_posterior_alpha
,
typical_acceptance_sampler_posterior_alpha
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_logprobs
=
speculative_config
.
disable_logprobs
,
disable_log_stats
=
speculative_config
.
disable_log_stats
,
disable_log_stats
=
speculative_config
.
disable_log_stats
)
tree_style_spec_decoding
=
speculative_config
.
tree_style_spec_decoding
,
)
return
spec_decode_worker
return
spec_decode_worker
...
@@ -126,7 +125,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -126,7 +125,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
disable_logprobs
:
bool
,
disable_logprobs
:
bool
,
disable_log_stats
:
bool
,
disable_log_stats
:
bool
,
tree_style_spec_decoding
:
bool
,
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
allow_zero_draft_token_step
=
True
...
@@ -191,8 +189,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -191,8 +189,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
disable_by_batch_size
=
disable_by_batch_size
,
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
tree_style_spec_decoding
=
tree_style_spec_decoding
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -204,7 +201,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -204,7 +201,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
tree_style_spec_decoding
:
bool
=
False
,
):
):
"""
"""
Create a SpecDecodeWorker.
Create a SpecDecodeWorker.
...
@@ -233,7 +229,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -233,7 +229,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
allow_zero_draft_token_step: whether to allow a step where the draft
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
draft model is larger than 1 (TODO: #5814)
tree_style_spec_decoding: Whether to use tree-style generation.
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
...
@@ -268,7 +263,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -268,7 +263,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_logprobs
=
disable_logprobs
self
.
_disable_log_stats
=
disable_log_stats
self
.
_disable_log_stats
=
disable_log_stats
self
.
tree_
style_spec_decoding
=
tree_style_spec_decoding
self
.
tree_
decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize both scorer and proposer models.
"""Initialize both scorer and proposer models.
...
@@ -285,7 +280,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -285,7 +280,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
if
not
self
.
tree_
style_spec_
decoding
:
if
not
self
.
tree_decoding
:
self
.
scorer
=
BatchExpansionTop1Scorer
(
self
.
scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
=
self
.
scorer_worker
,
scorer_worker
=
self
.
scorer_worker
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -324,7 +319,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -324,7 +319,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
=
True
)
=
True
# tree_style decoding modify probs in _verify_tokens
# tree_style decoding modify probs in _verify_tokens
if
not
self
.
tree_
style_spec_
decoding
:
if
not
self
.
tree_decoding
:
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
)
=
True
should_modify_greedy_probs_inplace
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
...
@@ -535,7 +530,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -535,7 +530,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
updated, so they cannot enable spec decode in the rest decoding.
"""
"""
if
self
.
tree_
style_spec_
decoding
and
self
.
kvcache_slot_to_be_moved
is
not
None
:
if
self
.
tree_decoding
and
self
.
kvcache_slot_to_be_moved
is
not
None
:
execute_model_req
.
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
execute_model_req
.
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
self
.
kvcache_slot_to_be_moved
=
None
self
.
kvcache_slot_to_be_moved
=
None
...
@@ -560,7 +555,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -560,7 +555,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
hidden_states
,
execute_model_req
.
seq_group_metadata_list
)
# Store logits from target model execution.
# Store logits from target model execution.
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
logits
=
sampler_output
.
logits
logits
=
sampler_output
.
logits
if
logits
is
not
None
:
if
logits
is
not
None
:
if
self
.
previous_logits
is
None
:
if
self
.
previous_logits
is
None
:
...
@@ -612,7 +607,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -612,7 +607,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
scorer_worker
.
execute_model
()
self
.
scorer_worker
.
execute_model
()
if
not
data
[
"disable_all_speculation"
]:
if
not
data
[
"disable_all_speculation"
]:
# if not self.tree_
style_spec_
decoding:
# if not self.tree_decoding:
# # Even if num_lookahead_slots is zero, we want to run the
# # Even if num_lookahead_slots is zero, we want to run the
# # proposer model as it may have KV.
# # proposer model as it may have KV.
# #
# #
...
@@ -677,7 +672,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -677,7 +672,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"workers generate no tokens"
)
"workers generate no tokens"
)
# Pass tree attention mask and postions to target model
# Pass tree attention mask and postions to target model
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
execute_model_req
.
tree_attn_masks
=
proposals
.
tree_attn_masks
execute_model_req
.
tree_attn_masks
=
proposals
.
tree_attn_masks
execute_model_req
.
tree_position_ids
=
proposals
.
tree_position_ids
execute_model_req
.
tree_position_ids
=
proposals
.
tree_position_ids
...
@@ -695,7 +690,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -695,7 +690,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
execute_model_req
.
num_lookahead_slots
)
proposals
,
execute_model_req
.
num_lookahead_slots
)
# move kv_caches of selected tokens to right positions
# move kv_caches of selected tokens to right positions
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
self
.
move_caches
(
execute_model_req
,
select_indices_list
,
accept_lengths
)
self
.
move_caches
(
execute_model_req
,
select_indices_list
,
accept_lengths
)
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
stage_times
=
(
proposal_timer
.
elapsed_time_ms
/
num_lookahead_slots
,
...
@@ -739,7 +734,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -739,7 +734,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
else
:
else
:
proposal_verifier_probs
=
proposal_scores
.
probs
proposal_verifier_probs
=
proposal_scores
.
probs
if
self
.
tree_
style_spec_
decoding
:
if
self
.
tree_decoding
:
retrieve_indices
=
proposals
.
retrieve_indices
retrieve_indices
=
proposals
.
retrieve_indices
proposal_verifier_probs
=
proposal_verifier_probs
[:,
retrieve_indices
]
proposal_verifier_probs
=
proposal_verifier_probs
[:,
retrieve_indices
]
...
@@ -797,7 +792,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -797,7 +792,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
)
# Append output tokens from non-speculative sequences to
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
# the accepted token ids tensor.
if
not
self
.
tree_
style_spec_
decoding
:
if
not
self
.
tree_decoding
:
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
1
).
clone
()
1
).
clone
()
else
:
else
:
...
...
vllm/worker/cache_engine.py
View file @
44e3ca68
...
@@ -8,6 +8,7 @@ from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
...
@@ -8,6 +8,7 @@ from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
is_pin_memory_available
)
is_pin_memory_available
)
from
vllm.attention.backends.tree_decoding_utils
import
move_cache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -103,11 +104,12 @@ class CacheEngine:
...
@@ -103,11 +104,12 @@ class CacheEngine:
def
move_caches
(
self
,
kv_caches
:
List
[
torch
.
Tensor
],
def
move_caches
(
self
,
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dsts
:
torch
.
Tensor
)
->
None
:
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
move_cache
(
kv_caches
,
move_cache
(
self
.
attn_backend
,
src_to_dsts
,
kv_caches
,
self
.
cache_config
.
cache_dtype
,
src_to_dsts
,
self
.
num_kv_heads
,
self
.
cache_config
.
cache_dtype
,
self
.
head_size
)
self
.
num_kv_heads
,
self
.
head_size
)
@
staticmethod
@
staticmethod
def
get_cache_block_size
(
def
get_cache_block_size
(
...
...
vllm/worker/model_runner.py
View file @
44e3ca68
...
@@ -198,7 +198,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -198,7 +198,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
lora_requests
.
clear
()
# type: ignore
self
.
lora_requests
.
clear
()
# type: ignore
self
.
prompt_adapter_index_mapping
.
clear
()
# type: ignore
self
.
prompt_adapter_index_mapping
.
clear
()
# type: ignore
self
.
prompt_adapter_prompt_mapping
.
clear
()
# type: ignore
self
.
prompt_adapter_prompt_mapping
.
clear
()
# type: ignore
self
.
tree_attn_masks
[
0
]
=
None
# type: ignore
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -246,9 +245,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -246,9 +245,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
reinit
:
bool
=
False
,
reinit
:
bool
=
False
,
reinit_use_defaults
:
bool
=
False
,
reinit_use_defaults
:
bool
=
False
,
encoder_seq_len
:
int
=
0
,
encoder_seq_len
:
int
=
0
,
# attention mask used in tree-style generation
tree_attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
):
):
if
reinit
:
if
reinit
:
assert
len
(
self
.
seq_ids
)
==
len
(
seq_ids
)
# type: ignore
assert
len
(
self
.
seq_ids
)
==
len
(
seq_ids
)
# type: ignore
...
@@ -339,12 +335,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -339,12 +335,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping
prompt_adapter_prompt_mapping
else
:
else
:
self
.
prompt_adapter_prompt_mapping
.
clear
()
self
.
prompt_adapter_prompt_mapping
.
clear
()
if
tree_attn_masks
:
self
.
tree_attn_masks
=
tree_attn_masks
else
:
self
.
tree_attn_masks
.
clear
()
else
:
else
:
self
.
input_tokens
=
input_tokens
or
[]
self
.
input_tokens
=
input_tokens
or
[]
self
.
input_positions
=
input_positions
or
[]
self
.
input_positions
=
input_positions
or
[]
...
@@ -364,7 +354,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -364,7 +354,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_index_mapping
or
[])
prompt_adapter_index_mapping
or
[])
self
.
prompt_adapter_prompt_mapping
=
(
self
.
prompt_adapter_prompt_mapping
=
(
prompt_adapter_prompt_mapping
or
[])
prompt_adapter_prompt_mapping
or
[])
self
.
tree_attn_masks
=
tree_attn_masks
or
[]
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
multi_modal_inputs
=
multi_modal_inputs
self
.
multi_modal_inputs
=
multi_modal_inputs
...
@@ -380,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -380,7 +369,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
input_tokens
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
input_tokens
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
input_positions
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
input_positions
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
tree_attn_masks
=
[
None
for
_
in
range
(
self
.
n_seqs
)]
self
.
mrope_input_positions
=
None
self
.
mrope_input_positions
=
None
self
.
seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
orig_seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
orig_seq_lens
=
[
0
]
*
self
.
n_seqs
...
@@ -469,13 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -469,13 +457,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
block_aligned_sliding_window
=
\
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
self
.
sliding_window_blocks
*
self
.
block_size
if
hasattr
(
self
.
runner
,
"tree_attn_masks"
):
self
.
tree_attn_masks
=
self
.
runner
.
tree_attn_masks
self
.
tree_position_ids
=
self
.
runner
.
tree_position_ids
else
:
self
.
tree_attn_masks
=
None
self
.
tree_position_ids
=
None
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder_model
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder_model
...
@@ -853,16 +834,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -853,16 +834,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
cuda_graph_pad_size
:
if
cuda_graph_pad_size
:
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
seq_lens
.
extend
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
# prepare tree attention masks
tree_attention_masks_tensor
=
self
.
tree_attn_masks
if
tree_attention_masks_tensor
is
not
None
:
tree_attention_masks_tensor
=
tree_attention_masks_tensor
.
contiguous
()
input_positions_tensor
=
self
.
tree_position_ids
.
contiguous
()
# Attention metadata.
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
,
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
tree_attention_masks_tensor
=
tree_attention_masks_tensor
)
# LoRA data.
# LoRA data.
lora_requests
=
set
()
lora_requests
=
set
()
...
@@ -1033,9 +1007,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1033,9 +1007,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
inter_data_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
inter_data_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
self
.
sampling_metadata_cache
:
SamplingMetadataCache
=
\
SamplingMetadataCache
()
SamplingMetadataCache
()
self
.
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
self
.
tree_position_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
...
@@ -1503,11 +1474,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1503,11 +1474,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@
property
@
property
def
vocab_size
(
self
)
->
int
:
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
return
self
.
model_config
.
get_vocab_size
()
def
set_tree_style_args
(
self
,
tree_attn_masks
:
Optional
[
torch
.
Tensor
],
tree_position_ids
:
Optional
[
torch
.
Tensor
]):
self
.
tree_attn_masks
=
tree_attn_masks
self
.
tree_position_ids
=
tree_position_ids
class
ModelRunner
(
GPUModelRunnerBase
[
ModelInputForGPUWithSamplingMetadata
]):
class
ModelRunner
(
GPUModelRunnerBase
[
ModelInputForGPUWithSamplingMetadata
]):
...
...
vllm/worker/worker_base.py
View file @
44e3ca68
...
@@ -31,6 +31,7 @@ class WorkerBase(ABC):
...
@@ -31,6 +31,7 @@ class WorkerBase(ABC):
"""
"""
model_input
:
Optional
[
ModelRunnerInputBase
]
=
None
model_input
:
Optional
[
ModelRunnerInputBase
]
=
None
tree_decoding
=
(
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
)
@
abstractmethod
@
abstractmethod
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
...
@@ -103,18 +104,6 @@ class WorkerBase(ABC):
...
@@ -103,18 +104,6 @@ class WorkerBase(ABC):
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
raise
NotImplementedError
@
property
@
abstractmethod
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise
NotImplementedError
@
property
@
property
@
abstractmethod
@
abstractmethod
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
...
@@ -138,10 +127,6 @@ class LoraNotSupportedWorkerBase(WorkerBase):
...
@@ -138,10 +127,6 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
return
None
@
property
@
property
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
def
cache_engines
(
self
)
->
Optional
[
List
[
CacheEngine
]]:
...
@@ -282,10 +267,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -282,10 +267,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
if
hasattr
(
self
.
model_runner
,
"set_tree_style_args"
):
self
.
model_runner
.
set_tree_style_args
(
tree_attn_masks
=
execute_model_req
.
tree_attn_masks
,
tree_position_ids
=
execute_model_req
.
tree_position_ids
)
model_input
:
ModelRunnerInputBase
=
(
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
self
.
model_runner
.
prepare_model_input
(
...
@@ -293,6 +274,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -293,6 +274,17 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req
.
virtual_engine
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
))
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
execute_model_req
.
tree_attn_masks
is
not
None
:
if
hasattr
(
model_input
,
"input_positions"
)
and
\
hasattr
(
model_input
,
"attn_metadata"
)
and
\
hasattr
(
model_input
.
attn_metadata
,
"tree_attention_masks_tensor"
):
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
.
tree_attention_masks_tensor
=
execute_model_req
.
tree_attn_masks
.
contiguous
()
model_input
=
dataclasses
.
replace
(
model_input
,
input_positions
=
execute_model_req
.
tree_position_ids
.
contiguous
(),
attn_metadata
=
attn_metadata
)
kwargs
=
extract_previous_hidden_states
(
execute_model_req
)
kwargs
=
extract_previous_hidden_states
(
execute_model_req
)
if
self
.
do_metadata_broadcast
:
if
self
.
do_metadata_broadcast
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment