sglang commit d9a20fd2 (unverified)
Use trtllm_mla decode kernel for draft extend in speculative decoding (#11664)
Authored by Qiaolin Yu on Oct 20, 2025; committed via GitHub on Oct 21, 2025.
Parent: b113c72e

Showing 2 changed files with 520 additions and 18 deletions:
  python/sglang/srt/layers/attention/trtllm_mla_backend.py  (+348, -18)
  python/sglang/test/attention/test_trtllm_mla_backend.py   (+172, -0)
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.flashinfer_mla_backend import (
    FlashInferMLAAttnBackend,
@@ -48,6 +49,151 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
# compute the LCM with other padding constraints.
TRTLLM_BLOCK_CONSTRAINT = 128


@triton.jit
def pad_draft_extend_query_kernel(
    q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]
    padded_q_ptr,  # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]
    seq_lens_q_ptr,  # Sequence lengths for each sequence [batch_size]
    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]
    batch_size,
    max_seq_len,
    num_heads,
    head_dim,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for padding draft extended query tensor with parallelized head and dim processing."""
    # Use 3D program IDs: (batch_seq, head_block, dim_block)
    batch_seq_pid = tl.program_id(0)
    head_pid = tl.program_id(1)
    dim_pid = tl.program_id(2)

    batch_id = batch_seq_pid // max_seq_len
    seq_pos = batch_seq_pid % max_seq_len

    if batch_id >= batch_size:
        return

    # Load accept length for this batch
    seq_len = tl.load(seq_lens_q_ptr + batch_id)
    if seq_pos >= seq_len:
        return

    # Load cumulative sum to get start position in input tensor
    input_start = tl.load(cumsum_ptr + batch_id)
    input_pos = input_start + seq_pos

    # Calculate head and dim block ranges
    head_start = head_pid * BLOCK_SIZE
    head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)
    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)

    dim_start = dim_pid * BLOCK_SIZE
    dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)
    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)

    # Calculate input offset
    input_offset = (
        input_pos * num_heads * head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Load data
    data = tl.load(
        q_ptr + input_offset,
        mask=head_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )

    # Calculate output offset
    output_offset = (
        batch_id * max_seq_len * num_heads * head_dim
        + seq_pos * num_heads * head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Store data
    tl.store(
        padded_q_ptr + output_offset,
        data,
        mask=head_mask[:, None] & dim_mask[None, :],
    )


@triton.jit
def unpad_draft_extend_output_kernel(
    raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
    output_ptr,  # Output tensor (-1, tp_q_head_num, v_head_dim)
    accept_length_ptr,  # Accept lengths for each sequence [batch_size]
    cumsum_ptr,  # Cumulative sum of accept lengths [batch_size + 1]
    batch_size,
    token_per_batch,
    tp_q_head_num,
    v_head_dim,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing."""
    batch_seq_pid = tl.program_id(0)
    head_pid = tl.program_id(1)
    dim_pid = tl.program_id(2)

    batch_id = batch_seq_pid // token_per_batch
    seq_pos = batch_seq_pid % token_per_batch

    if batch_id >= batch_size:
        return

    # Load accept length for this batch
    accept_len = tl.load(accept_length_ptr + batch_id)
    if seq_pos >= accept_len:
        return

    # Load cumulative sum to get start position in output tensor
    output_start = tl.load(cumsum_ptr + batch_id)
    output_pos = output_start + seq_pos

    # Calculate head and dim block ranges
    head_start = head_pid * BLOCK_SIZE
    head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)
    head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)

    dim_start = dim_pid * BLOCK_SIZE
    dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)
    dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)

    # Calculate input offset: (batch_id, seq_pos, head_id, dim_id)
    input_offset = (
        batch_id * token_per_batch * tp_q_head_num * v_head_dim
        + seq_pos * tp_q_head_num * v_head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Load data
    data = tl.load(
        raw_out_ptr + input_offset,
        mask=head_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )

    output_offset = (
        output_pos * tp_q_head_num * v_head_dim
        + (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
        + (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
    )

    # Store data
    tl.store(
        output_ptr + output_offset,
        data,
        mask=head_mask[:, None] & dim_mask[None, :],
    )


global_zero_init_workspace_buffer = None
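For orientation, here is a pure-PyTorch sketch of what the two kernels above compute; it is illustrative only (the `*_ref` function names are made up here), and the shapes follow the comments in the kernel signatures. The pad kernel scatters a ragged, concatenated query tensor into a zero-initialized per-request padded buffer, and the unpad kernel gathers the first seq_len rows of each request back into a ragged tensor.

import torch

def pad_draft_extend_query_ref(q, padded_q, seq_lens_q, cu_seqlens_q):
    # q:        [total_tokens, num_heads, head_dim], requests concatenated back to back
    # padded_q: [batch_size, max_seq_len, num_heads, head_dim], assumed zero-initialized
    for i, seq_len in enumerate(seq_lens_q.tolist()):
        start = int(cu_seqlens_q[i])
        padded_q[i, :seq_len] = q[start : start + seq_len]
    return padded_q

def unpad_draft_extend_output_ref(raw_out, cu_seqlens_q, seq_lens_q, sum_seq_lens_q):
    # raw_out: [batch_size, token_per_batch, tp_q_head_num, v_head_dim]
    out = torch.empty(
        (sum_seq_lens_q, raw_out.shape[2], raw_out.shape[3]),
        dtype=raw_out.dtype,
        device=raw_out.device,
    )
    for i, seq_len in enumerate(seq_lens_q.tolist()):
        start = int(cu_seqlens_q[i])
        out[start : start + seq_len] = raw_out[i, :seq_len]
    return out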
@@ -65,7 +211,11 @@ class TRTLLMMLADecodeMetadata:
    """Metadata for TRTLLM MLA decode operations."""

    block_kv_indices: Optional[torch.Tensor] = None
-   max_seq_len: Optional[int] = None
+   max_seq_len_k: Optional[int] = None
    max_seq_len_q: Optional[int] = None
    sum_seq_lens_q: Optional[int] = None
    cu_seqlens_q: Optional[torch.Tensor] = None
    seq_lens_q: Optional[torch.Tensor] = None


class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -120,6 +270,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        # CUDA graph state
        self.decode_cuda_graph_metadata = {}
        self.decode_cuda_graph_kv_indices = None
        self.padded_q_buffer = None
        self.unpad_output_buffer = None
        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
@@ -203,6 +355,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        self.decode_cuda_graph_kv_indices = torch.full(
            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
        )

        num_tokens_per_bs = max_num_tokens // max_bs

        # Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
        self.padded_q_buffer = torch.zeros(
            (max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
            dtype=self.data_type,
            device=self.device,
        )

        # Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
        self.unpad_output_buffer = torch.zeros(
            (max_num_tokens, self.num_q_heads, 512),
            dtype=self.data_type,
            device=self.device,
        )

        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
@@ -219,7 +386,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        """Initialize metadata for CUDA graph capture."""
        # Delegate to parent for non-decode modes.
-       if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
+       if (
+           not forward_mode.is_decode_or_idle()
+           and not forward_mode.is_target_verify()
+           and not forward_mode.is_draft_extend()
+       ):
            return super().init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
@@ -259,6 +430,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            block_kv_indices,
            max_seq_len_val,
        )

        if forward_mode.is_draft_extend():
            num_tokens_per_bs = num_tokens // bs
            metadata.max_seq_len_q = num_tokens_per_bs + 1
            metadata.sum_seq_lens_q = num_tokens_per_bs * bs
            metadata.cu_seqlens_q = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                num_tokens_per_bs,
                dtype=torch.int32,
                device=seq_lens.device,
            )
            metadata.seq_lens_q = torch.full(
                (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
            )

        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_decode_metadata = metadata
@@ -275,7 +460,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
    ):
        """Replay CUDA graph with new inputs."""
        # Delegate to parent for non-decode modes.
-       if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
+       if (
+           not forward_mode.is_decode_or_idle()
+           and not forward_mode.is_target_verify()
+           and not forward_mode.is_draft_extend()
+       ):
            return super().init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
@@ -293,6 +482,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        metadata = self.decode_cuda_graph_metadata[bs]

        if forward_mode.is_draft_extend():
            accept_length = spec_info.accept_length[:bs]
            if spec_info.accept_length_cpu:
                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
            else:
                metadata.max_seq_len_q = 1
                metadata.sum_seq_lens_q = bs
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
            metadata.seq_lens_q.copy_(accept_length)

        # Update block indices for new sequences.
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
@@ -344,6 +546,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        elif (
            forward_batch.forward_mode.is_decode_or_idle()
            or forward_batch.forward_mode.is_target_verify()
            or forward_batch.forward_mode.is_draft_extend()
        ):
            bs = forward_batch.batch_size
@@ -372,6 +575,23 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                block_kv_indices,
                max_seq_len_val,
            )
            if forward_batch.forward_mode.is_draft_extend():
                max_seq = forward_batch.seq_lens_cpu.max().item()
                sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
                max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(
                        forward_batch.extend_seq_lens, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
                self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
                self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
                self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
        else:
            return super().init_forward_metadata(forward_batch)
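As a small numeric illustration of the metadata built in the branch above (the lengths are hypothetical), cu_seqlens_q is the per-request extend lengths cumulatively summed and prefixed with a zero:

import torch

extend_seq_lens = torch.tensor([2, 3, 1], dtype=torch.int32)  # hypothetical per-request draft-extend lengths
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_q)  # tensor([0, 2, 5, 6], dtype=torch.int32)
# max_seq_len_q = max(extend lengths) = 3, sum_seq_lens_q = 6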
@@ -457,6 +677,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        return q_out, k_nope_out, k_rope_out

    def pad_draft_extend_query(
        self,
        q: torch.Tensor,
        padded_q: torch.Tensor,
        seq_lens_q: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
    ) -> torch.Tensor:
        """Pad draft extended query using Triton kernel."""
        batch_size = cu_seqlens_q.shape[0] - 1
        max_seq_len_q = padded_q.shape[1]
        num_heads = padded_q.shape[2]
        head_dim = padded_q.shape[3]

        # Launch Triton kernel with 3D grid for parallelized head and dim processing
        BLOCK_SIZE = 64
        num_head_blocks = triton.cdiv(num_heads, BLOCK_SIZE)
        num_dim_blocks = triton.cdiv(head_dim, BLOCK_SIZE)
        grid = (batch_size * max_seq_len_q, num_head_blocks, num_dim_blocks)

        pad_draft_extend_query_kernel[grid](
            q_ptr=q,
            padded_q_ptr=padded_q,
            seq_lens_q_ptr=seq_lens_q,
            cumsum_ptr=cu_seqlens_q,
            batch_size=batch_size,
            max_seq_len=max_seq_len_q,
            num_heads=num_heads,
            head_dim=head_dim,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return padded_q

    def unpad_draft_extend_output(
        self,
        raw_out: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seq_lens_q: torch.Tensor,
        sum_seq_lens_q: int,
    ) -> torch.Tensor:
        """Unpad draft extended output using Triton kernel."""
        # raw_out: (batch_size, token_per_batch, layer.tp_q_head_num, layer.v_head_dim)
        batch_size = seq_lens_q.shape[0]
        token_per_batch = raw_out.shape[1]  # max_seq_len
        tp_q_head_num = raw_out.shape[2]  # num_heads
        v_head_dim = raw_out.shape[3]  # head_dim
        total_tokens = sum_seq_lens_q

        # Check if we're in CUDA graph mode (buffers are pre-allocated)
        if self.unpad_output_buffer is not None:
            # Use pre-allocated buffer for CUDA graph compatibility
            output = self.unpad_output_buffer[:total_tokens, :, :].to(
                dtype=raw_out.dtype
            )
        else:
            # Dynamic allocation for non-CUDA graph mode
            output = torch.empty(
                (total_tokens, tp_q_head_num, v_head_dim),
                dtype=raw_out.dtype,
                device=raw_out.device,
            )

        # Launch Triton kernel with 3D grid for parallelized head and dim processing
        BLOCK_SIZE = 64
        num_head_blocks = triton.cdiv(tp_q_head_num, BLOCK_SIZE)
        num_dim_blocks = triton.cdiv(v_head_dim, BLOCK_SIZE)
        grid = (batch_size * token_per_batch, num_head_blocks, num_dim_blocks)

        unpad_draft_extend_output_kernel[grid](
            raw_out_ptr=raw_out,
            output_ptr=output,
            accept_length_ptr=seq_lens_q,
            cumsum_ptr=cu_seqlens_q,
            batch_size=batch_size,
            token_per_batch=token_per_batch,
            tp_q_head_num=tp_q_head_num,
            v_head_dim=v_head_dim,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return output[:total_tokens, :, :]

    def forward_decode(
        self,
        q: torch.Tensor,  # q_nope
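A standalone sketch of the launch pattern used by the two helpers above, with made-up sizes (two requests with 2 and 3 draft tokens). It drives pad_draft_extend_query_kernel directly and follows the same BLOCK_SIZE=64, 3D-grid convention; everything besides the kernel itself is illustrative.

import torch
import triton

from sglang.srt.layers.attention.trtllm_mla_backend import pad_draft_extend_query_kernel

batch_size, max_seq_len_q, num_heads, head_dim = 2, 3, 16, 64  # hypothetical sizes
seq_lens_q = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")
q = torch.randn(5, num_heads, head_dim, device="cuda")  # ragged queries
padded_q = torch.zeros(batch_size, max_seq_len_q, num_heads, head_dim, device="cuda")

BLOCK_SIZE = 64
grid = (
    batch_size * max_seq_len_q,          # one program per (request, position) pair
    triton.cdiv(num_heads, BLOCK_SIZE),  # head blocks
    triton.cdiv(head_dim, BLOCK_SIZE),   # dim blocks
)
pad_draft_extend_query_kernel[grid](
    q_ptr=q,
    padded_q_ptr=padded_q,
    seq_lens_q_ptr=seq_lens_q,
    cumsum_ptr=cu_seqlens_q,
    batch_size=batch_size,
    max_seq_len=max_seq_len_q,
    num_heads=num_heads,
    head_dim=head_dim,
    BLOCK_SIZE=BLOCK_SIZE,
)
# padded_q[0, :2] now equals q[0:2], padded_q[1, :3] equals q[2:5],
# and padded_q[0, 2] stays zero (the padding slot for the shorter request).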
@@ -550,7 +850,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=metadata.block_kv_indices,
            seq_lens=forward_batch.seq_lens.to(torch.int32),
-           max_seq_len=metadata.max_seq_len,
+           max_seq_len=metadata.max_seq_len_k,
            bmm1_scale=bmm1_scale,
        )
@@ -571,11 +871,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        cos_sin_cache: Optional[torch.Tensor] = None,
        is_neox: Optional[bool] = False,
    ) -> torch.Tensor:
-       if forward_batch.forward_mode.is_draft_extend():
-           return super().forward_extend(
-               q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-           )
        # TODO refactor to avoid code duplication
        merge_query = q_rope is not None
        if (
@@ -627,7 +922,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)

-       if forward_batch.forward_mode.is_target_verify():
+       if (
+           forward_batch.forward_mode.is_target_verify()
+           or forward_batch.forward_mode.is_draft_extend()
+       ):
            metadata = (
                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
                or self.forward_decode_metadata
@@ -635,7 +933,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
            bs = forward_batch.batch_size
            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
@@ -646,17 +943,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                if getattr(layer, "k_scale_float", None) is not None
                else 1.0
            )

            q = q.to(self.data_type)

            bmm1_scale = q_scale * k_scale * layer.scaling

            if forward_batch.forward_mode.is_target_verify():
                seq_lens = (
                    forward_batch.seq_lens.to(torch.int32)
                    + forward_batch.spec_info.draft_token_num
                )
-               max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+               max_seq_len = (
+                   metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
+               )
            else:
                seq_lens = forward_batch.seq_lens.to(torch.int32)
                max_seq_len = metadata.max_seq_len_k

                # Check if we're in CUDA graph mode (buffers are pre-allocated)
                if self.padded_q_buffer is not None:
                    # Use pre-allocated buffer for CUDA graph compatibility
                    padded_q = self.padded_q_buffer[
                        :bs, : metadata.max_seq_len_q, :, :
                    ].to(dtype=q.dtype)
                else:
                    # Dynamic allocation for non-CUDA graph mode
                    padded_q = torch.zeros(
                        bs,
                        metadata.max_seq_len_q,
                        layer.tp_q_head_num,
                        layer.head_dim,
                        dtype=q.dtype,
                        device=q.device,
                    )
                q = self.pad_draft_extend_query(
                    q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
                )

            # TODO may use `mla_rope_quantize_fp8` fusion
            q = q.to(self.data_type)
            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

            assert kv_cache.dtype == self.data_type

            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
@@ -673,6 +995,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            )

            # Reshape output directly without slicing
            if forward_batch.forward_mode.is_draft_extend():
                raw_out = self.unpad_draft_extend_output(
                    raw_out,
                    metadata.cu_seqlens_q,
                    metadata.seq_lens_q,
                    metadata.sum_seq_lens_q,
                )
            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
            return output
python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -1263,6 +1263,178 @@ class TestTRTLLMMLA(CustomTestCase):
            f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
        )

    def test_draft_extend_padding_unpadding_kernels(self):
        """Test TRTLLM MLA Triton kernels: pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel."""
        # Import the kernels
        from sglang.srt.layers.attention.trtllm_mla_backend import (
            pad_draft_extend_query_kernel,
            unpad_draft_extend_output_kernel,
        )

        def _create_test_data(
            self, batch_size, max_seq_len, num_heads, head_dim, dtype=torch.float32
        ):
            """Create test data for kernel testing."""
            device = torch.device("cuda")

            # Create sequence lengths (varying lengths for each batch)
            seq_lens = torch.randint(
                1, max_seq_len + 1, (batch_size,), device=device, dtype=torch.int32
            )

            # Create cumulative sequence lengths
            cum_seq_lens = torch.zeros(
                batch_size + 1, device=device, dtype=torch.int32
            )
            cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)

            # Create input query tensor (flattened format)
            total_tokens = cum_seq_lens[-1].item()
            q_input = torch.randn(
                total_tokens, num_heads, head_dim, device=device, dtype=dtype
            )

            # Create padded query tensor (batch format)
            padded_q = torch.zeros(
                batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype
            )
            return q_input, padded_q, seq_lens, cum_seq_lens

        def _create_test_output_data(
            self,
            batch_size,
            token_per_batch,
            tp_q_head_num,
            v_head_dim,
            dtype=torch.float32,
        ):
            """Create test data for unpad kernel testing."""
            device = torch.device("cuda")

            # Create accept lengths (varying lengths for each batch)
            accept_lengths = torch.randint(
                1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32
            )

            # Create cumulative accept lengths
            cum_accept_lengths = torch.zeros(
                batch_size + 1, device=device, dtype=torch.int32
            )
            cum_accept_lengths[1:] = torch.cumsum(accept_lengths, dim=0)

            # Create raw output tensor (batch format)
            raw_out = torch.randn(
                batch_size,
                token_per_batch,
                tp_q_head_num,
                v_head_dim,
                device=device,
                dtype=dtype,
            )

            # Create output tensor (flattened format)
            total_tokens = cum_accept_lengths[-1].item()
            output = torch.empty(
                total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype
            )
            return raw_out, output, accept_lengths, cum_accept_lengths

        # Test 1: pad_draft_extend_query_kernel basic functionality
        with self.subTest(test="pad_kernel_basic"):
            batch_size = 4
            max_seq_len = 8
            num_heads = 16
            head_dim = 64

            q_input, padded_q, seq_lens, cum_seq_lens = _create_test_data(
                self, batch_size, max_seq_len, num_heads, head_dim
            )

            # Launch kernel
            BLOCK_SIZE = 64
            grid = (batch_size * max_seq_len,)
            pad_draft_extend_query_kernel[grid](
                q_ptr=q_input,
                padded_q_ptr=padded_q,
                seq_lens_q_ptr=seq_lens,
                cumsum_ptr=cum_seq_lens,
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_heads=num_heads,
                head_dim=head_dim,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            # Verify the padding worked correctly
            for i in range(batch_size):
                seq_len = seq_lens[i].item()
                # Check that valid positions are copied correctly
                for pos in range(seq_len):
                    input_start = cum_seq_lens[i].item()
                    input_pos = input_start + pos
                    # Compare input and output for valid positions
                    input_data = q_input[input_pos]
                    output_data = padded_q[i, pos]
                    torch.testing.assert_close(
                        input_data, output_data, rtol=1e-5, atol=1e-6
                    )
                # Check that invalid positions are zero
                for pos in range(seq_len, max_seq_len):
                    output_data = padded_q[i, pos]
                    self.assertTrue(
                        torch.allclose(output_data, torch.zeros_like(output_data)),
                        f"Position {pos} in batch {i} should be zero",
                    )

        # Test 2: unpad_draft_extend_output_kernel basic functionality
        with self.subTest(test="unpad_kernel_basic"):
            batch_size = 4
            token_per_batch = 8
            tp_q_head_num = 16
            v_head_dim = 64

            raw_out, output, accept_lengths, cum_accept_lengths = (
                _create_test_output_data(
                    self, batch_size, token_per_batch, tp_q_head_num, v_head_dim
                )
            )

            # Launch kernel
            BLOCK_SIZE = 64
            grid = (batch_size * token_per_batch,)
            unpad_draft_extend_output_kernel[grid](
                raw_out_ptr=raw_out,
                output_ptr=output,
                accept_length_ptr=accept_lengths,
                cumsum_ptr=cum_accept_lengths,
                batch_size=batch_size,
                token_per_batch=token_per_batch,
                tp_q_head_num=tp_q_head_num,
                v_head_dim=v_head_dim,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            # Verify the unpadding worked correctly
            for i in range(batch_size):
                accept_len = accept_lengths[i].item()
                output_start = cum_accept_lengths[i].item()
                # Check that valid positions are copied correctly
                for pos in range(accept_len):
                    input_data = raw_out[i, pos]
                    output_data = output[output_start + pos]
                    torch.testing.assert_close(
                        input_data, output_data, rtol=1e-5, atol=1e-6
                    )


if __name__ == "__main__":
    unittest.main()