Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d94e3026
Unverified
Commit
d94e3026
authored
Aug 13, 2025
by
Giancarlo Delfin
Committed by
GitHub
Aug 13, 2025
Browse files
[V1] Add tree drafting tests for eagle spec decoding (#22705)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@meta.com
>
parent
3f52738d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
178 additions
and
55 deletions
+178
-55
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+157
-3
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+0
-6
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+3
-3
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+18
-43
No files found.
tests/v1/spec_decode/test_eagle.py
View file @
d94e3026
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
unittest
import
mock
from
unittest
import
mock
import
pytest
import
pytest
...
@@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
...
@@ -23,7 +24,11 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
eagle3_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def
_create_proposer
(
method
:
str
,
k
:
int
)
->
EagleProposer
:
def
_create_proposer
(
method
:
str
,
num_speculative_tokens
:
int
,
speculative_token_tree
:
Optional
[
list
[
tuple
[
int
]]]
=
None
,
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
runner
=
"generate"
,
max_model_len
=
100
)
max_model_len
=
100
)
...
@@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
...
@@ -31,12 +36,18 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
# Choose model directory based on method
# Choose model directory based on method
draft_model_dir
=
eagle_dir
if
method
==
"eagle"
else
eagle3_dir
draft_model_dir
=
eagle_dir
if
method
==
"eagle"
else
eagle3_dir
spec_token_tree_str
=
None
if
speculative_token_tree
is
not
None
:
assert
num_speculative_tokens
==
len
(
speculative_token_tree
)
spec_token_tree_str
=
str
(
speculative_token_tree
)
speculative_config
=
SpeculativeConfig
(
speculative_config
=
SpeculativeConfig
(
target_model_config
=
model_config
,
target_model_config
=
model_config
,
target_parallel_config
=
ParallelConfig
(),
target_parallel_config
=
ParallelConfig
(),
model
=
draft_model_dir
,
model
=
draft_model_dir
,
method
=
method
,
method
=
method
,
num_speculative_tokens
=
k
,
num_speculative_tokens
=
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree_str
,
)
)
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
...
@@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
...
@@ -189,7 +200,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model
.
lm_head
=
mock
.
MagicMock
()
target_model
.
lm_head
=
mock
.
MagicMock
()
# Create proposer using the helper function
# Create proposer using the helper function
proposer
=
_create_proposer
(
method
,
k
=
8
)
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
)
# Call the method under test
# Call the method under test
proposer
.
load_model
(
target_model
)
proposer
.
load_model
(
target_model
)
...
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -226,6 +237,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
pytest
.
skip
(
"TRITON_ATTN_VLLM_V1 does not support "
pytest
.
skip
(
"TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform"
)
"multi-token eagle spec decode on current platform"
)
if
(
attn_backend
==
"TREE_ATTN"
):
pytest
.
skip
(
"TREE_ATTN is tested separately in test_propose_tree"
"because it requires special input mocking."
)
if
attn_backend
==
"FLASH_ATTN_VLLM_V1"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"FLASH_ATTN_VLLM_V1"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
...
@@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -378,3 +393,142 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Verify all tokens match our expectations
# Verify all tokens match our expectations
assert
torch
.
equal
(
result
,
expected_tokens
)
assert
torch
.
equal
(
result
,
expected_tokens
)
@
pytest
.
mark
.
parametrize
(
"spec_token_tree"
,
[
[(
0
,
)],
# A single token
[(
0
,
),
(
0
,
0
),
(
0
,
0
,
0
)],
# Chain
[(
0
,
),
(
1
,
),
(
2
,
)],
# Parallel
[(
0
,
),
(
1
,
),
(
2
,
),
(
0
,
0
),
(
0
,
1
),
(
1
,
0
),
(
1
,
1
),
(
2
,
0
),
(
2
,
1
)],
# Tree
])
def
test_propose_tree
(
spec_token_tree
):
# Get GPU device.
device
=
torch
.
device
(
current_platform
.
device_type
)
# Setup test parameters.
batch_size
=
2
seq_len_1
=
5
seq_len_2
=
3
total_tokens
=
seq_len_1
+
seq_len_2
vocab_size
=
100
seq_lens
=
[
seq_len_1
,
seq_len_2
]
num_speculative_tokens
=
len
(
spec_token_tree
)
# Create proposer first so we can use its actual hidden_size.
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size
=
proposer
.
hidden_size
# Helper to create deterministic logits that will produce specific tokens
def
create_deterministic_logits
(
token_ids
,
k
:
int
):
logits
=
torch
.
full
((
batch_size
,
vocab_size
),
-
100.0
,
device
=
device
)
for
i
,
token_id
in
enumerate
(
token_ids
):
# Assign decreasing values to the k, consecutive, tokens.
for
j
in
range
(
k
):
logits
[
i
,
token_id
+
j
]
=
100.0
-
j
return
logits
# Mock a model that returns deterministic logits.
base_token_ids
=
torch
.
tensor
([
42
,
60
],
dtype
=
torch
.
int64
,
device
=
device
)
# Skip loading the model and replace it with a mock that returns
# deterministic outputs.
model_mock
=
mock
.
MagicMock
()
# Mock the model forward calls.
forward_returns
=
[(
torch
.
zeros
(
total_tokens
,
hidden_size
,
device
=
device
),
torch
.
zeros
(
total_tokens
,
hidden_size
,
device
=
device
))]
for
cu_num_drafts
in
proposer
.
cu_drafts_per_level
:
h_logits
=
torch
.
zeros
(
batch_size
*
cu_num_drafts
,
hidden_size
,
device
=
device
)
h_states
=
torch
.
zeros
(
batch_size
*
cu_num_drafts
,
hidden_size
,
device
=
device
)
forward_returns
.
append
((
h_logits
,
h_states
))
model_mock
.
side_effect
=
forward_returns
# Mock the compute_logits calls.
cu_num_drafts_tensor
=
torch
.
tensor
([
0
]
+
proposer
.
cu_drafts_per_level
,
dtype
=
torch
.
int32
,
device
=
device
)
logits_returns
=
[]
for
level
,
num_children
in
enumerate
(
proposer
.
child_drafts_per_level
):
token_ids
=
base_token_ids
+
cu_num_drafts_tensor
[
level
]
level_num_drafts
=
cu_num_drafts_tensor
[
level
+
1
]
-
cu_num_drafts_tensor
[
level
]
level_logits
=
[]
for
i
in
range
(
level_num_drafts
//
num_children
):
level_logits
.
append
(
create_deterministic_logits
(
token_ids
+
i
*
num_children
,
num_children
))
logits_returns
.
append
(
torch
.
stack
(
level_logits
,
dim
=
1
))
model_mock
.
compute_logits
.
side_effect
=
logits_returns
# Assign the mock to the proposer
proposer
.
model
=
model_mock
# Assign draft attn_layer_names since load_model is not invoked
proposer
.
attn_layer_names
=
[
"layer.0"
]
# Get the tree attention metadata builder.
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
TREE_ATTN
)
attn_metadata_builder
=
attn_metadata_builder_cls
(
kv_cache_spec
=
create_standard_kv_cache_spec
(
proposer
.
vllm_config
),
layer_names
=
proposer
.
attn_layer_names
,
vllm_config
=
proposer
.
vllm_config
,
device
=
device
,
)
# Mock runner for attention metadata building.
proposer
.
runner
=
mock
.
MagicMock
()
proposer
.
runner
.
attn_groups
.
append
([
mock
.
MagicMock
()])
proposer
.
runner
.
attn_groups
[
0
][
0
].
metadata_builder
=
attn_metadata_builder
# Setup inputs for the proposer.
target_token_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
total_tokens
,
),
device
=
device
)
target_positions
=
torch
.
cat
([
torch
.
arange
(
seq_len_1
,
device
=
device
),
torch
.
arange
(
seq_len_2
,
device
=
device
)
])
target_hidden_states
=
torch
.
randn
(
total_tokens
,
hidden_size
,
device
=
device
)
next_token_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
device
)
batch_spec
=
BatchSpec
(
seq_lens
=
seq_lens
,
query_lens
=
seq_lens
,
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
sampling_metadata
=
mock
.
MagicMock
()
# Propose draft tokens.
result
=
proposer
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
common_attn_metadata
=
common_attn_metadata
,
sampling_metadata
=
sampling_metadata
)
assert
result
.
shape
==
(
batch_size
,
num_speculative_tokens
)
# The tokens are expected to be consecutive integers starting
# from the base token IDs.
expected_tokens
=
base_token_ids
[:,
None
]
+
torch
.
arange
(
num_speculative_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
# Verify that the draft tokens match our expectations.
assert
torch
.
equal
(
result
,
expected_tokens
)
tests/v1/spec_decode/test_max_len.py
View file @
d94e3026
...
@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
...
@@ -39,12 +39,6 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens
:
int
,
attn_backend
:
str
):
num_speculative_tokens
:
int
,
attn_backend
:
str
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
if
attn_backend
==
"TREE_ATTN"
and
num_speculative_tokens
>
1
:
# TREE_ATTN fails the test with multi-token spec decode
# TODO: Investigate why
pytest
.
skip
(
"TREE_ATTN fails the test"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
(
attn_backend
==
"TRITON_ATTN_VLLM_V1"
if
(
attn_backend
==
"TRITON_ATTN_VLLM_V1"
...
...
vllm/v1/attention/backends/tree_attn.py
View file @
d94e3026
...
@@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
...
@@ -236,9 +236,9 @@ class TreeAttentionMetadataBuilder(
# Use prefill for drafting at the root level.
# Use prefill for drafting at the root level.
self
.
tree_attn_bias
=
torch
.
empty
(
0
)
self
.
tree_attn_bias
=
torch
.
empty
(
0
)
else
:
else
:
# Slice the tree attention bias for drafting.
# Slice the tree attention bias for drafting.
Exclude
query_len
=
common_attn_metadata
.
max_query_len
# the root level.
start
,
end
=
draft_index
,
draft_index
+
query_len
start
,
end
=
1
,
1
+
common_attn_metadata
.
max_
query_len
self
.
tree_attn_bias
=
self
.
tree_attn_bias
[
start
:
end
,
self
.
tree_attn_bias
=
self
.
tree_attn_bias
[
start
:
end
,
start
:
end
].
contiguous
()
start
:
end
].
contiguous
()
...
...
vllm/v1/spec_decode/eagle.py
View file @
d94e3026
...
@@ -113,13 +113,6 @@ class EagleProposer:
...
@@ -113,13 +113,6 @@ class EagleProposer:
num_drafts_per_level
[
level
])
num_drafts_per_level
[
level
])
self
.
child_drafts_per_level
.
append
(
num_drafts_per_level
[
level
]
//
self
.
child_drafts_per_level
.
append
(
num_drafts_per_level
[
level
]
//
num_drafts_per_level
[
level
-
1
])
num_drafts_per_level
[
level
-
1
])
# Find the first level where the tree branches off into one or more
# children.
self
.
first_branching_level
=
None
for
level
in
range
(
tree_depth
):
if
self
.
cu_drafts_per_level
[
level
]
>
level
+
1
:
self
.
first_branching_level
=
level
break
# Precompute draft position offsets in flattened tree.
# Precompute draft position offsets in flattened tree.
self
.
tree_draft_pos_offsets
=
torch
.
arange
(
self
.
tree_draft_pos_offsets
=
torch
.
arange
(
1
,
1
,
...
@@ -209,11 +202,10 @@ class EagleProposer:
...
@@ -209,11 +202,10 @@ class EagleProposer:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
positions
=
target_positions
[
last_token_indices
]
positions
=
target_positions
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
first_branching_level
==
0
:
# Branching has occurred at the root level. Draft using tree
if
isinstance
(
attn_metadata
,
TreeAttentionMetadata
):
# attention.
#
Draft using tree
attention.
draft_token_ids_list
=
self
.
propose_tree
(
draft_token_ids_list
=
self
.
propose_tree
(
tree_root_level
=
0
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
logits
=
logits
,
logits
=
logits
,
positions
=
positions
,
positions
=
positions
,
...
@@ -242,11 +234,10 @@ class EagleProposer:
...
@@ -242,11 +234,10 @@ class EagleProposer:
(
TritonAttentionMetadata
,
AiterFlashAttentionMetadata
,
(
TritonAttentionMetadata
,
AiterFlashAttentionMetadata
,
FlashAttentionMetadata
))
FlashAttentionMetadata
))
else
:
else
:
# Currently, only FlashAttention and TreeAttention support
# Currently, only FlashAttention supports multi-token eagle spec
# multi-token eagle spec decode. This is because the code below
# decode. This is because the code below makes assumptions about
# makes assumptions about attn_metadata attributes available.
# attn_metadata attributes available.
assert
isinstance
(
attn_metadata
,
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
(
FlashAttentionMetadata
,
TreeAttentionMetadata
))
# Generate the remaining draft tokens.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
draft_token_ids_list
=
[
draft_token_ids
]
...
@@ -259,7 +250,7 @@ class EagleProposer:
...
@@ -259,7 +250,7 @@ class EagleProposer:
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
# tensor.argmax() returns int64 by default.
...
@@ -327,21 +318,6 @@ class EagleProposer:
...
@@ -327,21 +318,6 @@ class EagleProposer:
hidden_states
=
hidden_states
[:
batch_size
]
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
None
)
if
self
.
first_branching_level
==
token_index
+
1
:
# Branching has occurred. The remaining tokens are drafted
# using tree attention.
draft_token_ids_list
+=
self
.
propose_tree
(
tree_root_level
=
token_index
+
1
,
batch_size
=
batch_size
,
logits
=
logits
,
positions
=
positions
,
hidden_states
=
hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
)
# [batch_size, num_tree_tokens]
return
torch
.
cat
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_token_ids_list
.
append
(
draft_token_ids
)
...
@@ -351,7 +327,6 @@ class EagleProposer:
...
@@ -351,7 +327,6 @@ class EagleProposer:
def
propose_tree
(
def
propose_tree
(
self
,
self
,
tree_root_level
:
int
,
batch_size
:
int
,
batch_size
:
int
,
# [num_tokens, vocab_size]
# [num_tokens, vocab_size]
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
...
@@ -366,10 +341,10 @@ class EagleProposer:
...
@@ -366,10 +341,10 @@ class EagleProposer:
assert
isinstance
(
tree_attn_metadata_builder
,
assert
isinstance
(
tree_attn_metadata_builder
,
TreeAttentionMetadataBuilder
)
TreeAttentionMetadataBuilder
)
total_num_drafts
=
self
.
cu_drafts_per_level
[
tree_root_level
]
total_num_drafts
=
self
.
cu_drafts_per_level
[
0
]
level_num_drafts
=
total_num_drafts
level_num_drafts
=
total_num_drafts
# Sample a draft token for each child at the tree root level.
# Sample a draft token for each child at the tree root level.
num_children
=
self
.
child_drafts_per_level
[
tree_root_level
]
num_children
=
self
.
child_drafts_per_level
[
0
]
if
num_children
==
1
:
if
num_children
==
1
:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
).
view
(
batch_size
,
-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
).
view
(
batch_size
,
-
1
)
else
:
else
:
...
@@ -393,22 +368,23 @@ class EagleProposer:
...
@@ -393,22 +368,23 @@ class EagleProposer:
positions
.
view
(
batch_size
,
-
1
)
+
positions
.
view
(
batch_size
,
-
1
)
+
self
.
tree_draft_pos_offsets
[:
batch_size
,
:])
self
.
tree_draft_pos_offsets
[:
batch_size
,
:])
tree_depth
=
len
(
self
.
cu_drafts_per_level
)
tree_depth
=
len
(
self
.
cu_drafts_per_level
)
for
level
in
range
(
tree_root_level
,
tree_depth
-
1
):
for
level
in
range
(
tree_depth
-
1
):
# Get draft positions for RoPE.
# Get draft positions for RoPE.
draft_positions
=
positions
+
(
level
+
1
)
draft_positions
=
positions
+
(
level
+
1
)
exceeds_max_model_len
=
(
positions
+
exceeds_max_model_len
=
(
positions
+
total_num_drafts
)
>=
self
.
max_model_len
total_num_drafts
)
>=
self
.
max_model_len
# Mask out the position ids that exceed the max model length.
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
# Otherwise, we may get out-of-range error in RoPE.
clamped_
draft_positions
=
torch
.
where
(
draft_positions
=
torch
.
where
(
exceeds_max_model_len
,
exceeds_max_model_len
,
0
,
0
,
draft_positions
,
draft_positions
,
)
).
view
(
batch_size
,
-
1
)
if
level_num_drafts
>
1
:
if
level_num_drafts
>
1
:
# Repeat the positions for each draft at this level.
# Repeat the positions for each draft at this level.
draft_positions
=
clamped_
draft_positions
.
repeat_interleave
(
draft_positions
=
draft_positions
.
repeat_interleave
(
level_num_drafts
).
reshape
(
batch_size
,
-
1
)
level_num_drafts
,
dim
=
1
)
if
num_children
>
1
:
if
num_children
>
1
:
# Repeat draft hidden states for each child.
# Repeat draft hidden states for each child.
...
@@ -425,7 +401,7 @@ class EagleProposer:
...
@@ -425,7 +401,7 @@ class EagleProposer:
# Build new attention metadata for the next level of drafts.
# Build new attention metadata for the next level of drafts.
# This is necessary to support tree attention.
# This is necessary to support tree attention.
query_len
=
total_num_drafts
-
tree_root_level
query_len
=
total_num_drafts
common_attn_metadata
=
replace
(
common_attn_metadata
=
replace
(
common_attn_metadata
,
common_attn_metadata
,
query_start_loc
=
query_len
*
self
.
arange
[:
batch_size
+
1
],
query_start_loc
=
query_len
*
self
.
arange
[:
batch_size
+
1
],
...
@@ -435,7 +411,7 @@ class EagleProposer:
...
@@ -435,7 +411,7 @@ class EagleProposer:
)
)
attn_metadata
=
tree_attn_metadata_builder
.
build_for_drafting
(
attn_metadata
=
tree_attn_metadata_builder
.
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
tree_root_
level
+
1
,
draft_index
=
level
+
1
,
)
)
# Apply new attention metadata to all layers.
# Apply new attention metadata to all layers.
...
@@ -516,7 +492,6 @@ class EagleProposer:
...
@@ -516,7 +492,6 @@ class EagleProposer:
level_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
level_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
-
total_num_drafts
1
]
-
total_num_drafts
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
return
draft_token_ids_list
return
draft_token_ids_list
def
prepare_inputs
(
def
prepare_inputs
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment