Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa7012eb
Unverified
Commit
aa7012eb
authored
Aug 03, 2025
by
Giancarlo Delfin
Committed by
GitHub
Aug 03, 2025
Browse files
Add tree attention backend for v1 (part 1) (#20401)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@meta.com
>
parent
c2e75b3c
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1098 additions
and
25 deletions
+1098
-25
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+1
-1
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+4
-2
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+4
-3
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+299
-0
vllm/attention/ops/triton_unified_attention.py
vllm/attention/ops/triton_unified_attention.py
+48
-0
vllm/config.py
vllm/config.py
+13
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+4
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-0
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+452
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+20
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+251
-18
No files found.
tests/v1/attention/test_attention_backends.py
View file @
aa7012eb
...
...
@@ -17,7 +17,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST
=
[
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
FLASHINFER_VLLM_V1
,
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
,
_Backend
.
TREE_ATTN
]
# Remove flashinfer from the list if it's not available
...
...
tests/v1/attention/utils.py
View file @
aa7012eb
...
...
@@ -126,6 +126,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
,
_Backend
.
TRITON_ATTN_VLLM_V1
:
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
,
_Backend
.
TREE_ATTN
:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
,
}
if
backend_name
not
in
backend_map
:
...
...
tests/v1/spec_decode/test_eagle.py
View file @
aa7012eb
...
...
@@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
def
test_propose
(
num_speculative_tokens
):
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
TREE_ATTN
])
def
test_propose
(
num_speculative_tokens
,
backend
):
# Use GPU device
device
=
torch
.
device
(
current_platform
.
device_type
)
...
...
@@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
device
=
device
)
sampling_metadata
=
mock
.
MagicMock
()
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
FLASH_ATTN_VLLM_V1
)
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
backend
)
attn_metadata_builder
=
attn_metadata_builder_cls
(
kv_cache_spec
=
create_standard_kv_cache_spec
(
proposer
.
vllm_config
),
layer_names
=
proposer
.
attn_layer_names
,
...
...
tests/v1/spec_decode/test_tree_attention.py
0 → 100644
View file @
aa7012eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
typing
import
Optional
import
torch
from
tests.v1.attention.utils
import
(
_Backend
,
create_standard_kv_cache_spec
,
create_vllm_config
,
get_attention_backend
)
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer, used to drive backend impls.

    The layer only carries unit ``_q_scale`` / ``_k_scale`` / ``_v_scale``
    tensors and an identity ``forward``.

    Args:
        device: Device on which the scale tensors are allocated. Defaults to
            ``"cuda"``, which is what the tests in this module use.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Allocate the scales per instance rather than in the class body:
        # class-body tensors are created when the module is imported, which
        # requires a CUDA device just to collect the tests and shares one
        # tensor object between every instance.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        self._k_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: _Backend,
    spec_token_tree: Optional[str] = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the given backend.

    Args:
        q: Query tensor; its shape is unpacked as
            ``(batch_size, q_len, num_heads, dim_per_head)``.
        k: Key tensor; ``k.shape[-2]`` is taken as the number of KV heads.
        v: Value tensor, flattened the same way as ``k``.
        kv_cache: Paged KV cache. A clone is handed to the backend.
        block_table: Block table forwarded in the common attention metadata.
        slot_mapping: Slot mapping forwarded in the common attention metadata.
        seqlen_k: KV sequence length assigned to every request in the batch.
        backend: Backend whose metadata builder and impl are used.
        spec_token_tree: Optional speculative token tree. When given, an
            "eagle" ``SpeculativeConfig`` is attached to the vLLM config.
        num_spec_tokens: ``num_speculative_tokens`` for that config.

    Returns:
        Whatever the backend implementation's ``forward`` returns.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32)
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size, ),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_query_len = q_len
    # Every request contributes q_len tokens, so the total is known on the
    # host. Use a plain int instead of indexing the device tensor
    # (query_start_loc[-1]), which would yield a 0-d tensor.
    num_actual_tokens = batch_size * q_len
    softmax_scale = q.shape[-1]**(-0.5)

    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = get_attention_backend(backend)
    # All entries of seq_lens equal seqlen_k, so pass that int directly
    # rather than max(seq_lens), which iterates a device tensor and returns
    # a 0-d tensor.
    vllm_config = create_vllm_config(model_name=model_name,
                                     max_model_len=seqlen_k)
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree)
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.cpu(),
        num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Run forward pass and return output.
    # Flatten the batch and token dimensions into one token dimension.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )
def test_tree_attn_correctness() -> None:
    """Check tree attention against per-branch flash attention.

    For each configuration the tree backend is run once over the whole
    token tree. Then, for every query row of the tree mask, the tokens that
    row attends to are run as a chain through the flash attention backend
    and compared with the corresponding rows of the tree output.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    mask_dtype = torch.int32

    # Chain.
    chain_mask = torch.tensor(
        [
            [1, 0, 0, 0],
            [1, 1, 0, 0],
            [1, 1, 1, 0],
            [1, 1, 1, 1],
        ],
        device=device,
        dtype=mask_dtype,
    )
    # Tree.
    branching_mask = torch.tensor(
        [
            [1, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0],
            [1, 0, 1, 0, 0, 0, 0],
            [1, 1, 0, 1, 0, 0, 0],
            [1, 1, 0, 0, 1, 0, 0],
            [1, 0, 1, 0, 0, 1, 0],
            [1, 0, 1, 0, 0, 0, 1],
        ],
        device=device,
        dtype=mask_dtype,
    )
    tree_attn_masks = {
        "[(0,), (0, 0), (0, 0, 0)]": chain_mask,
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": branching_mask,
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 128
    max_sequence_length = 8192
    randomize_blocks = True

    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Assert that the number of heads is divisible
                    # by the number of KV heads.
                    assert num_heads % num_kv_heads == 0

                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q_shape = (batch_size, tree_size_q, num_heads,
                               dim_per_head)
                    kv_shape = (batch_size, tree_size_q, num_kv_heads,
                                dim_per_head)
                    q = torch.randn(q_shape,
                                    device=device,
                                    dtype=torch.bfloat16)
                    k = torch.randn(kv_shape,
                                    device=device,
                                    dtype=torch.bfloat16)
                    v = torch.randn(kv_shape,
                                    device=device,
                                    dtype=torch.bfloat16)
                    dev = q.device

                    # Setup the block table and KV cache for paged KV.
                    assert max_sequence_length % block_size == 0
                    max_blocks_per_batch = max_sequence_length // block_size
                    cache_shape = (
                        2,
                        batch_size * max_blocks_per_batch,
                        block_size,
                        num_kv_heads,
                        dim_per_head,
                    )
                    kv_cache = torch.randn(cache_shape,
                                           device=dev,
                                           dtype=torch.bfloat16)
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k /
                                                           block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=dev,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=dev,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        shuffle = torch.randperm(block_ids.numel())
                        block_ids = block_ids[shuffle]
                    block_table[:, :num_alloc_blocks_per_batch] = (
                        block_ids.view(-1, num_alloc_blocks_per_batch))

                    # Setup the slot mapping for the input KVs.
                    tree_offsets = torch.arange(
                        0,
                        tree_size_q,
                        device=dev,
                        dtype=torch.int64,
                    )
                    tree_positions = (
                        sequence_position +
                        tree_offsets.repeat(batch_size, 1))
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size)

                    # Compute attention for the tree.
                    tree_result = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=_Backend.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    )
                    tree_attn_output = tree_result.view(
                        batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask,
                                                       as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_offsets = torch.arange(
                            0,
                            q_len,
                            device=dev,
                            dtype=torch.int64,
                        )
                        branch_positions = (
                            sequence_position +
                            branch_offsets.repeat(batch_size, 1))
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size)

                        # Compute flash attention for the branch.
                        flash_result = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=_Backend.FLASH_ATTN_VLLM_V1,
                        )
                        flash_attn_output = flash_result.view(
                            batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        outputs_match = torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        )
                        assert outputs_match, (
                            f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}.")
def
_gen_slot_mapping
(
positions
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
block_size
:
int
):
block_indices
=
positions
//
block_size
blocks
=
block_table
.
gather
(
dim
=
1
,
index
=
block_indices
)
return
(
blocks
*
block_size
+
positions
%
block_size
).
view
(
-
1
)
vllm/attention/ops/triton_unified_attention.py
View file @
aa7012eb
...
...
@@ -55,6 +55,7 @@ def kernel_unified_attention_2d(
block_tables_ptr
,
# [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr
,
# [num_seqs]
alibi_slopes_ptr
,
# [num_query_heads]
qq_bias_ptr
,
# [num_query_tokens, num_query_tokens]
scale
,
# float32
k_scale
,
# float32
v_scale
,
# float32
...
...
@@ -66,10 +67,12 @@ def kernel_unified_attention_2d(
query_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
output_stride_0
:
tl
.
int64
,
# int
output_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
qq_bias_stride_0
:
tl
.
int64
,
# int
BLOCK_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
USE_QQ_BIAS
:
tl
.
constexpr
,
# bool
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
SLIDING_WINDOW
:
tl
.
constexpr
,
# int
stride_k_cache_0
:
tl
.
int64
,
# int
...
...
@@ -144,6 +147,11 @@ def kernel_unified_attention_2d(
mask
=
query_mask_1
,
other
=
0.0
)
# query-query attention bias
if
USE_QQ_BIAS
:
qq_bias_row_ptrs
=
(
qq_bias_ptr
+
query_pos
[:,
None
]
*
qq_bias_stride_0
)
# shape: [BLOCK_M]
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len
=
context_len
+
q_block_local_idx
*
BLOCK_Q
+
(
...
...
@@ -223,6 +231,18 @@ def kernel_unified_attention_2d(
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
if
USE_QQ_BIAS
:
# compute key positions relative to query section
key_rel_pos
=
seq_offset
-
context_len
# shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key
=
key_rel_pos
>=
0
and
key_rel_pos
<
qq_bias_stride_0
qq_bias
=
tl
.
load
(
qq_bias_row_ptrs
+
key_rel_pos
[
None
,
:],
mask
=
is_query_key
[
None
,
:],
# avoid OOB for context keys
other
=
0.0
,
)
S
+=
qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j
=
tl
.
maximum
(
M
,
tl
.
max
(
S
,
axis
=
1
))
...
...
@@ -275,6 +295,7 @@ def kernel_unified_attention_3d(
block_tables_ptr
,
# [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr
,
# [num_seqs]
alibi_slopes_ptr
,
# [num_query_heads]
qq_bias_ptr
,
# [num_query_tokens, num_query_tokens]
scale
,
# float32
k_scale
,
# float32
v_scale
,
# float32
...
...
@@ -284,10 +305,12 @@ def kernel_unified_attention_3d(
block_table_stride
:
tl
.
int64
,
# int
query_stride_0
:
tl
.
int64
,
# int
query_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
qq_bias_stride_0
:
tl
.
int64
,
# int
BLOCK_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
USE_QQ_BIAS
:
tl
.
constexpr
,
# bool
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
SLIDING_WINDOW
:
tl
.
constexpr
,
# int
stride_k_cache_0
:
tl
.
int64
,
# int
...
...
@@ -373,6 +396,11 @@ def kernel_unified_attention_3d(
mask
=
query_mask_1
,
other
=
0.0
)
# query-query attention bias
if
USE_QQ_BIAS
:
qq_bias_row_ptrs
=
(
qq_bias_ptr
+
query_pos
[:,
None
]
*
qq_bias_stride_0
)
# shape: [BLOCK_M]
num_blocks
=
cdiv_fn
(
seq_len
,
BLOCK_SIZE
)
# iterate through tiles within current segment
...
...
@@ -442,6 +470,18 @@ def kernel_unified_attention_3d(
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
if
USE_QQ_BIAS
:
# compute key positions relative to query section
key_rel_pos
=
seq_offset
-
context_len
# shape: [BLOCK_SIZE]
# load bias only for keys that correspond to queries
is_query_key
=
key_rel_pos
>=
0
and
key_rel_pos
<
qq_bias_stride_0
qq_bias
=
tl
.
load
(
qq_bias_row_ptrs
+
key_rel_pos
[
None
,
:],
mask
=
is_query_key
[
None
,
:],
# avoid OOB for context keys
other
=
0.0
,
)
S
+=
qq_bias
# compute running maximum
# m_j : (BLOCK_M,)
m_j
=
tl
.
maximum
(
M
,
tl
.
max
(
S
,
axis
=
1
))
...
...
@@ -586,6 +626,7 @@ def unified_attention(
k_descale
,
v_descale
,
alibi_slopes
=
None
,
qq_bias
=
None
,
):
assert
causal
,
"Only causal attention is supported"
assert
q_descale
is
None
,
"Q scales not supported"
...
...
@@ -595,6 +636,7 @@ def unified_attention(
"Block size must be at least 32 for fp8"
use_alibi_slopes
=
alibi_slopes
is
not
None
use_qq_bias
=
qq_bias
is
not
None
block_size
=
v
.
shape
[
1
]
num_seqs
=
len
(
seqused_k
)
...
...
@@ -630,6 +672,7 @@ def unified_attention(
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
qq_bias_ptr
=
qq_bias
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
...
...
@@ -641,10 +684,12 @@ def unified_attention(
query_stride_1
=
q
.
stride
(
1
),
output_stride_0
=
out
.
stride
(
0
),
output_stride_1
=
out
.
stride
(
1
),
qq_bias_stride_0
=
qq_bias
.
stride
(
0
)
if
use_qq_bias
else
0
,
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
...
...
@@ -699,6 +744,7 @@ def unified_attention(
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
qq_bias_ptr
=
qq_bias
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
...
...
@@ -708,10 +754,12 @@ def unified_attention(
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
q
.
stride
(
0
),
query_stride_1
=
q
.
stride
(
1
),
qq_bias_stride_0
=
qq_bias
.
stride
(
0
)
if
use_qq_bias
else
0
,
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
...
...
vllm/config.py
View file @
aa7012eb
...
...
@@ -3049,6 +3049,19 @@ class SpeculativeConfig:
f
"num_speculative_tokens:
{
self
.
num_speculative_tokens
}
"
f
" must be divisible by
{
n_predict
=
}
"
)
if
self
.
speculative_token_tree
is
None
:
# Generate chain of tokens.
self
.
speculative_token_tree
=
str
([
(
i
+
1
)
*
(
0
,
)
for
i
in
range
(
self
.
num_speculative_tokens
)
])
else
:
# Sort the token tree breadth-first.
tree_choices
=
ast
.
literal_eval
(
self
.
speculative_token_tree
)
self
.
speculative_token_tree
=
str
(
sorted
(
tree_choices
,
key
=
lambda
t
:
(
len
(
t
),
t
)))
self
.
draft_tensor_parallel_size
=
\
SpeculativeConfig
.
_verify_and_get_draft_tp
(
self
.
target_parallel_config
,
...
...
vllm/engine/arg_utils.py
View file @
aa7012eb
...
...
@@ -1454,7 +1454,6 @@ class EngineArgs:
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp."
)
# No XFormers so far.
V1_BACKENDS
=
[
"FLASH_ATTN_VLLM_V1"
,
"FLASH_ATTN"
,
...
...
@@ -1469,6 +1468,7 @@ class EngineArgs:
"ROCM_AITER_MLA"
,
"TORCH_SDPA_VLLM_V1"
,
"FLEX_ATTENTION"
,
"TREE_ATTN"
,
]
if
(
envs
.
is_set
(
"VLLM_ATTENTION_BACKEND"
)
and
envs
.
VLLM_ATTENTION_BACKEND
not
in
V1_BACKENDS
):
...
...
vllm/platforms/cuda.py
View file @
aa7012eb
...
...
@@ -270,6 +270,7 @@ class CudaPlatformBase(Platform):
FLEX_ATTENTION_V1
=
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
# noqa: E501
TRITON_ATTN_VLLM_V1
=
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
# noqa: E501
FLASH_ATTN_V1
=
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
# noqa: E501
TREE_ATTN_V1
=
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
# noqa: E501
if
selected_backend
==
_Backend
.
FLASHINFER
:
logger
.
info_once
(
"Using FlashInfer backend on V1 engine."
)
...
...
@@ -287,6 +288,9 @@ class CudaPlatformBase(Platform):
elif
selected_backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
FLASH_ATTN_V1
elif
selected_backend
==
_Backend
.
TREE_ATTN
:
logger
.
info_once
(
"Using Tree Attention backend on V1 engine."
)
return
TREE_ATTN_V1
from
vllm.attention.selector
import
is_attn_backend_supported
...
...
vllm/platforms/interface.py
View file @
aa7012eb
...
...
@@ -62,6 +62,7 @@ class _Backend(enum.Enum):
DIFFERENTIAL_FLASH_ATTN
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
FLEX_ATTENTION
=
enum
.
auto
()
TREE_ATTN
=
enum
.
auto
()
class
PlatformEnum
(
enum
.
Enum
):
...
...
vllm/v1/attention/backends/tree_attn.py
0 → 100644
View file @
aa7012eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with TreeAttention."""
import
ast
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
reorder_batch_to_split_decodes_and_prefills
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
class TreeAttentionBackend(AttentionBackend):
    """Backend descriptor wiring together the tree attention classes."""

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the accepted head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not a supported size."""
        allowed = cls.get_supported_head_sizes()
        if head_size in allowed:
            return
        # Strip the trailing "Backend" from the class name for the message.
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {allowed}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the backend's name string."""
        return "TREE_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the attention implementation class."""
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TreeAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV cache shape for the given dimensions.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        remainder = block_size % 16
        if remainder != 0:
            raise ValueError("Block size must be a multiple of 16.")
        # Leading dimension of 2: the impl unbinds it into key/value caches.
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always return ``False``, whatever the arguments."""
        return False
@dataclass
class TreeAttentionMetadata:
    """Per-batch attention metadata for the tree attention backend.

    The ``prefill_metadata`` and ``decode_metadata`` properties slice this
    object into its prefill and decode parts. The slicing places the first
    ``num_decodes`` requests (``num_decode_tokens`` tokens) in the decode
    part and the remainder in the prefill part.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    tree_attn_bias: Optional[torch.Tensor] = None

    # Cached Prefill/decode metadata.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata for the prefill requests, or ``None`` if there are none.

        Built on first access and cached. The tree attention bias is not
        forwarded to the prefill view.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is None:
            first_prefill = self.num_decodes
            prefill_q_loc = self.query_start_loc[first_prefill:]
            prefill_q_lens = torch.diff(prefill_q_loc)
            prefill_kv_lens = self.seq_lens[first_prefill:]
            # Construct & cache prefill-phase attention metadata structure
            cached = TreeAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(prefill_q_lens.max().item()),
                # Rebase so the first prefill query starts at offset 0.
                query_start_loc=prefill_q_loc - prefill_q_loc[0],
                max_seq_len=int(prefill_kv_lens.max().item()),
                seq_lens=prefill_kv_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
            self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata for the decode requests, or ``None`` if there are none.

        Built on first access and cached. The tree attention bias is
        forwarded to the decode view.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is None:
            decode_count = self.num_decodes
            decode_q_loc = self.query_start_loc[:decode_count + 1]
            decode_q_lens = torch.diff(decode_q_loc)
            decode_kv_lens = self.seq_lens[:decode_count]
            # Construct & cache decode-phase attention metadata structure
            cached = TreeAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(decode_q_lens.max().item()),
                query_start_loc=decode_q_loc,
                max_seq_len=int(decode_kv_lens.max().item()),
                seq_lens=decode_kv_lens,
                block_table=self.block_table[:decode_count],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                tree_attn_bias=self.tree_attn_bias,
            )
            self._cached_decode_metadata = cached
        return cached
class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds ``TreeAttentionMetadata`` from common attention metadata.

    The tree attention bias is computed once in ``__init__`` from the
    speculative token tree. Its number of rows is used as the decode
    threshold when splitting a batch into decodes and prefills.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Parse the speculative token tree and build its attention bias.

        Args:
            kv_cache_spec: KV cache spec; its ``block_size`` is recorded.
            layer_names: Not used by this builder.
            vllm_config: Config whose ``speculative_config``, if present,
                supplies ``speculative_token_tree``.
            device: Device on which the attention bias is allocated.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec_config.speculative_token_tree
                           if spec_config is not None else None)
        # Without a configured tree, fall back to a single draft token.
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch into decodes and prefills.

        Delegates to ``reorder_batch_to_split_decodes_and_prefills`` using
        the number of rows of the tree attention bias as the decode
        threshold, and returns its result.
        """
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Assemble the tree attention metadata for one batch.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Source of query/sequence lengths, block
                table and slot mapping.
            fast_build: Not used by this builder.

        Returns:
            Metadata carrying the decode/prefill split and the builder's
            current tree attention bias.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        # Take the maximum from the CPU copy of the sequence lengths.
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step.

        The builder's bias is swapped temporarily: an empty tensor for
        ``draft_index == 0``, otherwise the square sub-block of the full
        bias starting at ``draft_index`` with side ``max_query_len``.

        Args:
            common_attn_metadata: Common metadata for the drafting step.
            draft_index: Index of the drafting step.

        Returns:
            Metadata built with the temporary bias.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias

        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting.
                query_len = common_attn_metadata.max_query_len
                start, end = draft_index, draft_index + query_len
                self.tree_attn_bias = self.tree_attn_bias[
                    start:end, start:end].contiguous()

            # Build attention bias.
            attn_metadata = self.build(0,
                                       common_attn_metadata,
                                       fast_build=True)
        finally:
            # Reset the tree attention bias to the original value, even if
            # building failed, so later builds do not see the sliced bias.
            self.tree_attn_bias = orig_tree_attn_bias

        return attn_metadata
def
_get_depth_counts
(
sorted_tree_choices
:
list
[
tuple
[
int
,
...]])
->
list
[
int
]:
# Count the number of choices at each depth of the tree.
depth_counts
=
[]
prev_depth
=
0
for
path
in
sorted_tree_choices
:
depth
=
len
(
path
)
if
depth
!=
prev_depth
:
depth_counts
.
append
(
0
)
depth_counts
[
depth
-
1
]
+=
1
prev_depth
=
depth
return
depth_counts
def
_prepare_tree_attn_bias
(
sorted_tree_choices
:
list
[
tuple
[
int
,
...]],
depth_counts
:
list
[
int
],
dtype
:
Optional
[
torch
.
dtype
],
device
:
Optional
[
torch
.
device
],
)
->
torch
.
Tensor
:
# +1 comes from the additional root node.
tree_len
=
len
(
sorted_tree_choices
)
+
1
tree_attn_mask
=
torch
.
full
((
tree_len
,
tree_len
),
-
torch
.
inf
,
device
=
device
,
dtype
=
dtype
)
# Set diagonal to all zeros. Each token should
# attend to itself.
mask_val
=
0
for
i
in
range
(
tree_len
):
tree_attn_mask
[
i
,
i
]
=
mask_val
# Set root to all zeros. All tokens attend to it.
tree_attn_mask
[:,
0
]
=
mask_val
# Set all ancestors to zeros.
start
=
0
for
i
in
range
(
len
(
depth_counts
)):
for
j
in
range
(
depth_counts
[
i
]):
cur_tree_choice
=
sorted_tree_choices
[
start
+
j
]
# Retrieve ancestor position.
if
len
(
cur_tree_choice
)
==
1
:
continue
ancestor_idx
=
[]
for
c
in
range
(
len
(
cur_tree_choice
)
-
1
):
ancestor_idx
.
append
(
sorted_tree_choices
.
index
(
cur_tree_choice
[:
c
+
1
])
+
1
)
tree_attn_mask
[
j
+
start
+
1
,
ancestor_idx
]
=
mask_val
start
+=
depth_counts
[
i
]
return
tree_attn_mask
class TreeAttentionImpl(AttentionImpl):
    """Attention layer implementation for the tree attention backend.

    Prefill tokens and decode (tree-drafted) tokens are both dispatched to
    ``unified_attention``; the decode call additionally passes the tree
    attention bias taken from the decode metadata as ``qq_bias``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Store the layer configuration and reject unsupported options.

        Raises:
            ValueError: If ``blocksparse_params`` is provided.
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "TreeAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # ALiBi slopes are kept as a float32 tensor when given.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        # Setting logits_soft_cap to 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        # Stored as a 2-tuple; (-1, -1) is used when no window is configured.
        self.sliding_window = ((-1, -1) if sliding_window is None else
                               (sliding_window - 1, 0))

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run tree attention and write the result into ``output``.

        Args:
            layer: Module providing the ``_k_scale`` / ``_v_scale`` tensors.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` during the
                profiling run, in which case ``output`` is returned untouched.
            output: Pre-allocated output tensor; must be provided.
            output_scale: Must be ``None``; fused output quantization is not
                supported.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping
            # is not padded. However, we don't need to do
            # key[:num_actual_tokens] and value[:num_actual_tokens] because
            # the reshape_and_cache_flash op uses the slot_mapping's shape to
            # determine the number of actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_seqs = attn_metadata.query_start_loc.shape[0] - 1
        descale_shape = (num_seqs, key.shape[1])

        # Arguments that are identical for the prefill and the decode call.
        shared_kwargs = dict(
            k=key_cache,
            v=value_cache,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
        )

        # Prefill tokens are laid out after the decode tokens.
        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            prefill_span = slice(num_decode_tokens, num_actual_tokens)
            unified_attention(
                q=query[prefill_span],
                out=output[prefill_span],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                block_table=prefill_meta.block_table,
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                **shared_kwargs,
            )

        # Decode tokens additionally receive the tree attention bias.
        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            decode_span = slice(None, num_decode_tokens)
            unified_attention(
                q=query[decode_span],
                out=output[decode_span],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                qq_bias=decode_meta.tree_attn_bias,
                block_table=decode_meta.block_table,
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                **shared_kwargs,
            )
        return output
vllm/v1/attention/backends/utils.py
View file @
aa7012eb
...
...
@@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return
self
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M:
    """Build attention metadata for the draft model.

    This default implementation ignores ``draft_index`` and delegates to
    ``build`` with no common prefix and ``fast_build`` enabled.

    Args:
        common_attn_metadata: The common attention metadata.
        draft_index: The index of the current draft operation.
            When speculating a chain of tokens, this index refers to the
            draft attempt for the i-th token.
            For tree-based attention, this index instead refers to the
            draft attempt for the i-th level in the tree of tokens.

    Returns:
        Whatever ``self.build`` returns for these arguments.
    """
    build_kwargs = {
        "common_prefix_len": 0,
        "common_attn_metadata": common_attn_metadata,
        "fast_build": True,
    }
    return self.build(**build_kwargs)
def
use_cascade_attention
(
self
,
common_prefix_len
:
int
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
aa7012eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
from
dataclasses
import
replace
from
typing
import
Optional
import
numpy
as
np
...
...
@@ -17,6 +19,8 @@ from vllm.model_executor.models import supports_multimodal
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.utils
import
is_pin_memory_available
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.tree_attn
import
(
TreeAttentionMetadata
,
TreeAttentionMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -74,18 +78,52 @@ class EagleProposer:
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
arange
=
torch
.
arange
(
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self
.
arange
=
torch
.
arange
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
max_batch_size
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
,
)
self
.
inputs_embeds
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
# Parse the speculative token tree.
spec_token_tree
=
self
.
speculative_config
.
speculative_token_tree
self
.
tree_choices
:
list
[
tuple
[
int
,
...]]
=
ast
.
literal_eval
(
spec_token_tree
)
tree_depth
=
len
(
self
.
tree_choices
[
-
1
])
# Precompute per-level properties of the tree.
num_drafts_per_level
=
[
0
]
*
tree_depth
for
node
in
self
.
tree_choices
:
num_drafts_per_level
[
len
(
node
)
-
1
]
+=
1
self
.
cu_drafts_per_level
=
[
num_drafts_per_level
[
0
]]
self
.
child_drafts_per_level
=
[
num_drafts_per_level
[
0
]]
for
level
in
range
(
1
,
tree_depth
):
self
.
cu_drafts_per_level
.
append
(
self
.
cu_drafts_per_level
[
-
1
]
+
num_drafts_per_level
[
level
])
self
.
child_drafts_per_level
.
append
(
num_drafts_per_level
[
level
]
//
num_drafts_per_level
[
level
-
1
])
# Find the first level where the tree branches off into one or more
# children.
self
.
first_branching_level
=
None
for
level
in
range
(
tree_depth
):
if
self
.
cu_drafts_per_level
[
level
]
>
level
+
1
:
self
.
first_branching_level
=
level
break
# Precompute draft position offsets in flattened tree.
self
.
tree_draft_pos_offsets
=
torch
.
arange
(
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
,
).
repeat
(
max_batch_size
,
1
)
def
propose
(
self
,
# [num_tokens]
...
...
@@ -120,11 +158,9 @@ class EagleProposer:
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
fast_build
=
True
,
)
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
0
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
...
...
@@ -167,6 +203,22 @@ class EagleProposer:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
positions
=
target_positions
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
first_branching_level
==
0
:
# Branching has occurred at the root level. Draft using tree
# attention.
draft_token_ids_list
=
self
.
propose_tree
(
tree_root_level
=
0
,
batch_size
=
batch_size
,
logits
=
logits
,
positions
=
positions
,
hidden_states
=
hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
)
# [batch_size, num_tree_tokens]
return
torch
.
cat
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
...
...
@@ -178,16 +230,15 @@ class EagleProposer:
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Currently FlashAttention
is the only backend that supports
#
multi-token
eagle spec decode. This is because the code below
# Currently
, only
FlashAttention
and TreeAttention support multi-token
# eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
(
FlashAttentionMetadata
,
TreeAttentionMetadata
))
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
positions
=
target_positions
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
...
...
@@ -196,7 +247,7 @@ class EagleProposer:
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
...
...
@@ -265,7 +316,20 @@ class EagleProposer:
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
# TODO(wenlong): get more than one token for tree attention
if
self
.
first_branching_level
==
token_index
+
1
:
# Branching has occurred. The remaining tokens are drafted
# using tree attention.
draft_token_ids_list
+=
self
.
propose_tree
(
tree_root_level
=
token_index
+
1
,
batch_size
=
batch_size
,
logits
=
logits
,
positions
=
positions
,
hidden_states
=
hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
)
# [batch_size, num_tree_tokens]
return
torch
.
cat
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
...
...
@@ -273,6 +337,175 @@ class EagleProposer:
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
def propose_tree(
    self,
    tree_root_level: int,
    batch_size: int,
    # [num_tokens, vocab_size]
    logits: torch.Tensor,
    # [num_tokens]
    positions: torch.Tensor,
    # [num_tokens, hidden_size]
    hidden_states: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
    """Draft the levels of the token tree from ``tree_root_level`` downwards.

    Tokens for ``tree_root_level`` are sampled from ``logits``. Then, for
    each further level, all drafts accumulated so far are fed through the
    draft model with freshly built tree attention metadata, and the tokens
    for the next level are sampled from the resulting logits.

    Args:
        tree_root_level: Tree level at which tree drafting starts.
        batch_size: Number of requests in the batch.
        logits: Logits used to sample the drafts at ``tree_root_level``.
        positions: Position of the token the tree is rooted at, one entry
            per request.
        hidden_states: Hidden states paired with ``positions``.
        common_attn_metadata: Metadata that is updated (via ``replace``) and
            passed to the tree attention metadata builder for every level.

    Returns:
        One tensor of draft token ids per drafted level, each reshaped to
        ``batch_size`` rows.
    """
    tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
    assert isinstance(tree_attn_metadata_builder,
                      TreeAttentionMetadataBuilder)

    total_num_drafts = self.cu_drafts_per_level[tree_root_level]
    level_num_drafts = total_num_drafts
    # Sample a draft token for each child at the tree root level.
    num_children = self.child_drafts_per_level[tree_root_level]
    if num_children == 1:
        draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
    else:
        draft_token_ids = torch.topk(logits, num_children,
                                     dim=-1).indices.view(batch_size, -1)
    draft_token_ids_list = [draft_token_ids]
    draft_hidden_states = hidden_states.view(batch_size, 1, -1)

    # Initialize empty tensors for concatenation with the level outputs.
    tree_input_ids = torch.empty(0,
                                 device=self.input_ids.device,
                                 dtype=self.input_ids.dtype)
    tree_positions = torch.empty(0,
                                 device=self.positions.device,
                                 dtype=self.positions.dtype)
    tree_hidden_states = torch.empty(0,
                                     device=self.hidden_states.device,
                                     dtype=self.hidden_states.dtype)
    # Precompute the draft token positions.
    flattened_draft_positions = (
        positions.view(batch_size, -1) +
        self.tree_draft_pos_offsets[:batch_size, :])
    tree_depth = len(self.cu_drafts_per_level)
    for level in range(tree_root_level, tree_depth - 1):
        # Get draft positions for RoPE.
        draft_positions = positions + (level + 1)
        exceeds_max_model_len = (positions +
                                 total_num_drafts) >= self.max_model_len
        # Mask out the position ids that exceed the max model length.
        # Otherwise, we may get out-of-range error in RoPE.
        clamped_draft_positions = torch.where(
            exceeds_max_model_len,
            0,
            draft_positions,
        )
        # Repeat the clamped positions for each draft at this level and lay
        # them out with one row per request. This is done unconditionally so
        # that a level with a single draft also uses the clamped values and
        # has the 2-D layout required by the concatenation along dim=1 below.
        draft_positions = clamped_draft_positions.repeat_interleave(
            level_num_drafts).reshape(batch_size, -1)

        if num_children > 1:
            # Repeat draft hidden states for each child.
            draft_hidden_states = draft_hidden_states.repeat_interleave(
                num_children, dim=1)

        # Concatenate the draft tokens, positions, and hidden states.
        tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
        tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
        tree_hidden_states = torch.cat(
            [tree_hidden_states, draft_hidden_states], dim=1)

        # Build new attention metadata for the next level of drafts.
        # This is necessary to support tree attention.
        query_len = total_num_drafts - tree_root_level
        common_attn_metadata = replace(
            common_attn_metadata,
            query_start_loc=query_len * self.arange[:batch_size + 1],
            seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
            num_actual_tokens=batch_size * query_len,
            max_query_len=query_len,
        )
        attn_metadata = tree_attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=tree_root_level + 1,
        )

        # Apply new attention metadata to all layers.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        # Consider max model length.
        attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                        self.max_model_len)
        # For the requests that exceed the max model length, we set the
        # sequence length to 1 to minimize their overheads in attention.
        attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

        # Compute the slot mapping.
        query_positions = flattened_draft_positions[:, level:level +
                                                    query_len]
        block_numbers = query_positions // self.block_size
        block_ids = attn_metadata.block_table.gather(dim=1,
                                                     index=block_numbers)
        slot_mapping = (block_ids * self.block_size +
                        query_positions % self.block_size)
        # Mask out the slot mappings that exceed the max model length.
        # Otherwise, the KV cache will be inadvertently updated with the
        # padding tokens.
        slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
        attn_metadata.slot_mapping = slot_mapping.view(-1)

        # Copy inputs to buffer for cudagraph.
        num_tokens = attn_metadata.num_actual_tokens
        input_ids = tree_input_ids.view(-1)
        self.input_ids[:num_tokens] = input_ids
        self.positions[:num_tokens] = tree_positions.view(-1)
        self.hidden_states[:num_tokens] = tree_hidden_states.view(
            num_tokens, -1)

        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_tokens)
        else:
            num_input_tokens = num_tokens
        # Run the model.
        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            last_hidden_states, hidden_states = self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self.positions[:num_input_tokens],
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=None,
            )

        # Get the output hidden states for the draft tokens. Only the last
        # ``level_num_drafts`` query slots per request belong to this level.
        draft_hidden_states = hidden_states[:num_tokens].view(
            batch_size, query_len, -1)[:, -level_num_drafts:]
        draft_last_hidden_states = last_hidden_states[:num_tokens].view(
            batch_size, query_len, -1)[:, -level_num_drafts:]

        # Get the output logits for the draft tokens.
        logits = self.model.compute_logits(
            draft_last_hidden_states.reshape(batch_size * level_num_drafts,
                                             -1),
            None,
        )

        # Sample a draft token for each child at the next tree level.
        num_children = self.child_drafts_per_level[level + 1]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children,
                                         dim=-1).indices.view(
                                             batch_size, -1)
        draft_token_ids_list.append(draft_token_ids)

        # Update the # drafts counters for the next tree level.
        level_num_drafts = self.cu_drafts_per_level[level +
                                                    1] - total_num_drafts
        total_num_drafts = self.cu_drafts_per_level[level + 1]
    return draft_token_ids_list
def
prepare_inputs
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment