Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05c19485
Unverified
Commit
05c19485
authored
Sep 24, 2025
by
Wei Wei
Committed by
GitHub
Sep 24, 2025
Browse files
[Kernel] Support DCP for Triton backend (#25132)
Signed-off-by:
Wei Wei
<
wwei6@meta.com
>
parent
52d0cb84
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
8 deletions
+30
-8
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+5
-0
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+17
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+1
-1
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+7
-5
No files found.
tests/kernels/attention/test_triton_decode_attention.py
View file @
05c19485
...
...
@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q
o
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
lse
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
"cuda"
)
b_seq_len
=
torch
.
full
((
B
,
),
seq_len
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
...
...
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
k_buffer
,
v_buffer
,
o
,
lse
,
req_to_token
,
b_seq_len
,
attn_logits
,
...
...
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
v_buffer
=
v_buffer
.
view
(
CACHE_SIZE
//
PAGE_SIZE
,
PAGE_SIZE
,
H_KV
,
D_V
)
o1
=
torch
.
zeros_like
(
o
)
lse1
=
torch
.
zeros_like
(
lse
)
decode_attention_fwd
(
q
,
k_buffer
,
v_buffer
,
o1
,
lse1
,
req_to_page
,
b_seq_len
,
attn_logits
,
...
...
vllm/attention/ops/triton_decode_attention.py
View file @
05c19485
...
...
@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
def
_fwd_kernel_stage2
(
Mid_O
,
o
,
lse
,
B_Seqlen
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_os
,
stride_obs
,
stride_oh
,
stride_lse_bs
,
NUM_KV_SPLITS
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
...
...
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
acc
/
e_sum
,
mask
=
mask_d
,
)
lse_val
=
e_max
+
tl
.
log
(
e_sum
)
tl
.
store
(
lse
+
cur_batch
*
stride_lse_bs
+
cur_head
,
lse_val
,
)
def
_decode_softmax_reducev_fwd
(
logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
,
...
...
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
_fwd_kernel_stage2
[
grid
](
logits
,
o
,
lse
,
b_seq_len
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
logits
.
stride
(
2
),
o
.
stride
(
0
),
o
.
stride
(
1
),
lse
.
stride
(
0
),
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
...
...
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
k_buffer
,
v_buffer
,
o
,
lse
,
req_to_token
,
b_seq_len
,
attn_logits
,
...
...
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
page_size
,
logit_cap
,
)
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
)
...
...
@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
k_buffer
,
v_buffer
,
o
,
lse
,
req_to_token
,
b_seq_len
,
attn_logits
,
...
...
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
page_size
,
logit_cap
,
)
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
)
...
...
@@ -633,6 +645,7 @@ def decode_attention_fwd(
k_buffer
,
v_buffer
,
o
,
lse
,
req_to_token
,
b_seq_len
,
attn_logits
,
...
...
@@ -651,6 +664,7 @@ def decode_attention_fwd(
k_buffer
,
v_buffer
,
o
,
lse
,
req_to_token
,
b_seq_len
,
attn_logits
,
...
...
@@ -666,6 +680,7 @@ def decode_attention_fwd(
k_buffer
,
v_buffer
,
o
,
lse
,
req_to_token
,
b_seq_len
,
attn_logits
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
05c19485
...
...
@@ -685,7 +685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
residual
=
hidden_states
.
clone
()
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
05c19485
...
...
@@ -32,6 +32,7 @@ class TritonMLABackend(MLACommonBackend):
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
can_return_lse_for_decode
:
bool
=
True
def
__init__
(
self
,
...
...
@@ -139,19 +140,20 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert
isinstance
(
q
,
torch
.
Tensor
)
B
=
q
.
shape
[
0
]
q_num_heads
=
q
.
shape
[
1
]
o
=
torch
.
zeros
(
B
,
self
.
num_heads
,
q_num_heads
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
lse
=
torch
.
zeros
(
B
,
q_num_heads
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
num_kv_splits
=
4
# TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits
=
torch
.
empty
(
(
B
,
self
.
num_heads
,
q_num_heads
,
num_kv_splits
,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
...
...
@@ -167,9 +169,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
lse
,
attn_metadata
.
decode
.
block_table
,
attn_metadata
.
decode
.
seq_lens
,
attn_logits
,
num_kv_splits
,
self
.
scale
,
PAGE_SIZE
)
return
o
,
None
return
o
,
lse
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment