Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
000214c4
Unverified
Commit
000214c4
authored
Feb 10, 2026
by
Vadim Gimpelson
Committed by
GitHub
Feb 10, 2026
Browse files
[BUGFIX] Fix accuracy bugs in Qwen3-Next MTP (#34077)
Signed-off-by:
Vadim Gimpelson
<
vadim.gimpelson@gmail.com
>
parent
c5a66d16
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
4 deletions
+18
-4
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+18
-4
No files found.
vllm/v1/attention/backends/gdn_attn.py
View file @
000214c4
...
...
@@ -208,7 +208,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_query_lens
=
query_lens
[
~
spec_sequence_masks
]
num_decodes
=
(
non_spec_query_lens
==
1
).
sum
().
item
()
num_prefills
=
non_spec_query_lens
.
size
(
0
)
-
num_decodes
# Exclude zero-length padded sequences from prefill count.
num_zero_len
=
(
non_spec_query_lens
==
0
).
sum
().
item
()
num_prefills
=
non_spec_query_lens
.
size
(
0
)
-
num_decodes
-
num_zero_len
num_decode_tokens
=
num_decodes
num_prefill_tokens
=
non_spec_query_lens
.
sum
().
item
()
-
num_decode_tokens
num_spec_decode_tokens
=
(
...
...
@@ -228,9 +230,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int32
,
device
=
query_start_loc
.
device
)
spec_state_indices_tensor
=
block_table_tensor
[:,
:
self
.
num_spec
+
1
]
# Filter by spec_sequence_masks to exclude padded sequences
spec_state_indices_tensor
=
block_table_tensor
[
spec_sequence_masks
,
:
self
.
num_spec
+
1
]
non_spec_state_indices_tensor
=
None
spec_query_start_loc
=
query_start_loc
# Padded sequences are always at the back, so the first
# num_spec_decodes + 1 entries of query_start_loc already
# contain the correct cumulative token counts.
spec_query_start_loc
=
query_start_loc
[:
num_spec_decodes
+
1
]
non_spec_query_start_loc
=
None
non_spec_query_start_loc_cpu
=
None
else
:
...
...
@@ -294,6 +302,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
else
:
has_initial_state
=
None
# The code in this function relies on the presence of either non-spec decodes
# or spec decodes, but not both.
assert
not
(
num_decodes
>
0
and
num_spec_decodes
>
0
),
(
f
"num_decodes:
{
num_decodes
}
, num_spec_decodes:
{
num_spec_decodes
}
"
)
# Prepare tensors for cudagraph
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
batch_size
=
m
.
num_actual_tokens
...
...
@@ -312,7 +326,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_state_indices_tensor
[
num_spec_decodes
:].
fill_
(
PAD_SLOT_ID
)
self
.
spec_sequence_masks
[:
num_spec_decodes
].
copy_
(
spec_sequence_masks
,
non_blocking
=
True
spec_sequence_masks
[:
num_spec_decodes
]
,
non_blocking
=
True
)
spec_sequence_masks
=
self
.
spec_sequence_masks
[:
batch_size
]
spec_sequence_masks
[
num_spec_decodes
:].
fill_
(
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment