Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1257fd1
Unverified
Commit
a1257fd1
authored
Mar 12, 2026
by
grimulkan
Committed by
GitHub
Mar 12, 2026
Browse files
[Kernel] Add FP8 KV cache support to Triton MLA decode attention (#34597)
Signed-off-by:
grimulkan
<
grimulkan@gmail.com
>
parent
abcffbba
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
192 additions
and
8 deletions
+192
-8
docs/design/attention_backends.md
docs/design/attention_backends.md
+1
-1
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+134
-0
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+10
-7
vllm/v1/attention/ops/triton_decode_attention.py
vllm/v1/attention/ops/triton_decode_attention.py
+47
-0
No files found.
docs/design/attention_backends.md
View file @
a1257fd1
...
@@ -213,5 +213,5 @@ configuration.
...
@@ -213,5 +213,5 @@ configuration.
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
`XPU_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
|
`XPU_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
tests/kernels/attention/test_triton_decode_attention.py
View file @
a1257fd1
...
@@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
...
@@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
)
)
assert
torch
.
allclose
(
o
,
o1
)
assert
torch
.
allclose
(
o
,
o1
)
def
_quantize_to_fp8
(
tensor
:
torch
.
Tensor
):
"""Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale.
Returns (fp8_tensor, scale) where:
fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn)
tensor ≈ fp8_tensor.to(float32) * scale (dequantized)
"""
amax
=
tensor
.
abs
().
amax
()
# float8_e4m3fn max representable value is 448.0
scale
=
(
amax
/
448.0
).
clamp
(
min
=
1e-12
).
to
(
torch
.
float32
)
fp8_tensor
=
(
(
tensor
.
to
(
torch
.
float32
)
/
scale
).
clamp
(
-
448.0
,
448.0
).
to
(
torch
.
float8_e4m3fn
)
)
return
fp8_tensor
,
scale
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
])
@
pytest
.
mark
.
parametrize
(
"L"
,
[
1025
])
@
pytest
.
mark
.
parametrize
(
"H_Q"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"H_KV"
,
[
32
,
8
])
@
pytest
.
mark
.
parametrize
(
"D_QK"
,
[
128
,
576
])
@
pytest
.
mark
.
parametrize
(
"D_V"
,
[
128
,
512
])
@
pytest
.
mark
.
parametrize
(
"CACHE_SIZE"
,
[
16384
])
@
pytest
.
mark
.
parametrize
(
"PAGE_SIZE"
,
[
1
,
16
])
def
test_decode_attention_fp8
(
B
,
L
,
H_Q
,
H_KV
,
D_QK
,
D_V
,
CACHE_SIZE
,
PAGE_SIZE
):
"""Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales,
and compare against BF16 reference output."""
assert
CACHE_SIZE
%
PAGE_SIZE
==
0
dtype
=
torch
.
bfloat16
seq_len
=
L
sm_scale
=
1.0
/
(
D_QK
**
0.5
)
num_kv_splits
=
8
num_pages_per_batch
=
cdiv
(
seq_len
,
PAGE_SIZE
)
req_to_page
=
torch
.
randint
(
0
,
CACHE_SIZE
//
PAGE_SIZE
,
(
B
,
num_pages_per_batch
,
1
),
device
=
"cuda"
)
req_to_token
=
req_to_page
*
PAGE_SIZE
req_to_token
=
req_to_token
.
expand
(
B
,
num_pages_per_batch
,
PAGE_SIZE
)
req_to_token
=
req_to_token
+
torch
.
arange
(
PAGE_SIZE
,
device
=
"cuda"
).
view
(
1
,
1
,
-
1
)
req_to_token
=
req_to_token
.
view
(
B
,
-
1
)
req_to_token
=
req_to_token
[:,
:
seq_len
].
contiguous
()
q
=
torch
.
randn
(
B
,
H_Q
,
D_QK
,
dtype
=
dtype
,
device
=
"cuda"
)
# Create BF16 K/V as reference
k_bf16
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_QK
,
dtype
=
dtype
,
device
=
"cuda"
)
v_bf16
=
torch
.
randn
(
CACHE_SIZE
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
# --- BF16 reference ---
o_ref
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
lse_ref
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
PAGE_SIZE
==
1
:
decode_attention_fwd
(
q
,
k_bf16
,
v_bf16
,
o_ref
,
lse_ref
,
req_to_token
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
attn_logits
=
attn_logits
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
)
else
:
k_paged
=
k_bf16
.
view
(
CACHE_SIZE
//
PAGE_SIZE
,
PAGE_SIZE
,
H_KV
,
D_QK
)
v_paged
=
v_bf16
.
view
(
CACHE_SIZE
//
PAGE_SIZE
,
PAGE_SIZE
,
H_KV
,
D_V
)
decode_attention_fwd
(
q
,
k_paged
,
v_paged
,
o_ref
,
lse_ref
,
req_to_page
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
attn_logits
=
attn_logits
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
page_size
=
PAGE_SIZE
,
)
# --- FP8 path ---
k_fp8
,
k_scale
=
_quantize_to_fp8
(
k_bf16
)
v_fp8
,
v_scale
=
_quantize_to_fp8
(
v_bf16
)
o_fp8
=
torch
.
zeros
(
B
,
H_Q
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
lse_fp8
=
torch
.
zeros
(
B
,
H_Q
,
dtype
=
dtype
,
device
=
"cuda"
)
attn_logits_fp8
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
PAGE_SIZE
==
1
:
decode_attention_fwd
(
q
,
k_fp8
,
v_fp8
,
o_fp8
,
lse_fp8
,
req_to_token
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
attn_logits
=
attn_logits_fp8
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
)
else
:
k_fp8_paged
=
k_fp8
.
view
(
CACHE_SIZE
//
PAGE_SIZE
,
PAGE_SIZE
,
H_KV
,
D_QK
)
v_fp8_paged
=
v_fp8
.
view
(
CACHE_SIZE
//
PAGE_SIZE
,
PAGE_SIZE
,
H_KV
,
D_V
)
decode_attention_fwd
(
q
,
k_fp8_paged
,
v_fp8_paged
,
o_fp8
,
lse_fp8
,
req_to_page
,
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
),
attn_logits
=
attn_logits_fp8
,
num_kv_splits
=
num_kv_splits
,
sm_scale
=
sm_scale
,
page_size
=
PAGE_SIZE
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
)
# FP8 tolerances match test_mla_backends.py test_backend_correctness.
torch
.
testing
.
assert_close
(
o_ref
,
o_fp8
,
atol
=
5e-1
,
rtol
=
1e-2
)
vllm/v1/attention/backends/mla/triton_mla.py
View file @
a1257fd1
...
@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
...
@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
"bfloat16"
,
"bfloat16"
,
"fp8"
,
"fp8_e4m3"
,
]
]
@
classmethod
@
classmethod
...
@@ -108,10 +110,11 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -108,10 +110,11 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"TritonMLAImpl"
"TritonMLAImpl"
)
)
# For FP8 KV cache, we dequantize to BF16 on load inside the
# Triton kernel. Tell the common layer not to quantize queries
# to FP8 — we handle FP8 KV cache with BF16 queries (Mode 1).
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
self
.
supports_quant_query_input
=
False
"TritonMLA V1 with FP8 KV cache not yet supported"
)
def
_flash_attn_varlen_diff_headdims
(
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
...
@@ -135,9 +138,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -135,9 +138,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 Triton MLA not yet supported"
)
if
type
(
q
)
is
tuple
:
if
type
(
q
)
is
tuple
:
q
=
torch
.
cat
(
q
,
dim
=-
1
)
q
=
torch
.
cat
(
q
,
dim
=-
1
)
...
@@ -171,7 +171,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -171,7 +171,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# Run MQA
# Run MQA — always pass layer scales. When KV cache is
# BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
decode_attention_fwd
(
decode_attention_fwd
(
q
,
q
,
kv_c_and_k_pe_cache
,
kv_c_and_k_pe_cache
,
...
@@ -184,6 +185,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -184,6 +185,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
num_kv_splits
,
num_kv_splits
,
self
.
scale
,
self
.
scale
,
PAGE_SIZE
,
PAGE_SIZE
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
)
)
return
o
,
lse
return
o
,
lse
vllm/v1/attention/ops/triton_decode_attention.py
View file @
a1257fd1
...
@@ -31,6 +31,7 @@ It supports page size >= 1.
...
@@ -31,6 +31,7 @@ It supports page size >= 1.
import
logging
import
logging
import
torch
from
packaging
import
version
from
packaging
import
version
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
...
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
stride_mid_ob
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_oh
,
stride_mid_os
,
stride_mid_os
,
k_scale
,
v_scale
,
kv_group_num
:
tl
.
constexpr
,
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
...
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
...
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
acc
=
tl
.
zeros
([
BLOCK_DV
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_DV
],
dtype
=
tl
.
float32
)
if
split_kv_end
>
split_kv_start
:
if
split_kv_end
>
split_kv_start
:
ks
=
tl
.
load
(
k_scale
)
vs
=
tl
.
load
(
v_scale
)
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_page_number
=
tl
.
load
(
kv_page_number
=
tl
.
load
(
...
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
...
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_d
[
None
,
:]),
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_d
[
None
,
:]),
other
=
0.0
,
other
=
0.0
,
)
)
if
k
.
dtype
.
is_fp8
():
k
=
(
k
.
to
(
tl
.
float32
)
*
ks
).
to
(
q
.
dtype
)
qk
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
qk
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
qk
*=
sm_scale
qk
*=
sm_scale
...
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
...
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
other
=
0.0
,
other
=
0.0
,
)
)
if
v
.
dtype
.
is_fp8
():
v
=
(
v
.
to
(
tl
.
float32
)
*
vs
).
to
(
q
.
dtype
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
0
),
e_max
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
0
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
...
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
...
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
k_scale
,
v_scale
,
):
):
BLOCK
=
64
if
not
is_hip_
else
8
BLOCK
=
64
if
not
is_hip_
else
8
...
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
...
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
att_out
.
stride
(
0
),
att_out
.
stride
(
0
),
att_out
.
stride
(
1
),
att_out
.
stride
(
1
),
att_out
.
stride
(
2
),
att_out
.
stride
(
2
),
k_scale
,
v_scale
,
kv_group_num
=
kv_group_num
,
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_DV
=
BLOCK_DV
,
...
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
...
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
stride_mid_ob
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_oh
,
stride_mid_os
,
stride_mid_os
,
k_scale
,
v_scale
,
kv_group_num
:
tl
.
constexpr
,
kv_group_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
...
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
...
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
if
split_kv_end
>
split_kv_start
:
if
split_kv_end
>
split_kv_start
:
ks
=
tl
.
load
(
k_scale
)
vs
=
tl
.
load
(
v_scale
)
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_page_number
=
tl
.
load
(
kv_page_number
=
tl
.
load
(
...
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
...
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_d
[:,
None
]),
other
=
0.0
,
other
=
0.0
,
)
)
if
k
.
dtype
.
is_fp8
():
k
=
(
k
.
to
(
tl
.
float32
)
*
ks
).
to
(
q
.
dtype
)
qk
=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
qk
=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_buf_kpe
=
(
offs_buf_kpe
=
(
...
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
...
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
other
=
0.0
,
other
=
0.0
,
)
)
if
kpe
.
dtype
.
is_fp8
():
kpe
=
(
kpe
.
to
(
tl
.
float32
)
*
ks
).
to
(
qpe
.
dtype
)
qk
+=
tl
.
dot
(
qpe
,
kpe
.
to
(
qpe
.
dtype
))
qk
+=
tl
.
dot
(
qpe
,
kpe
.
to
(
qpe
.
dtype
))
qk
*=
sm_scale
qk
*=
sm_scale
...
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
...
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
other
=
0.0
,
other
=
0.0
,
)
)
if
v
.
dtype
.
is_fp8
():
v
=
(
v
.
to
(
tl
.
float32
)
*
vs
).
to
(
q
.
dtype
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
...
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
...
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
k_scale
,
v_scale
,
):
):
BLOCK
=
32
BLOCK
=
32
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
...
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
...
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
att_out
.
stride
(
0
),
att_out
.
stride
(
0
),
att_out
.
stride
(
1
),
att_out
.
stride
(
1
),
att_out
.
stride
(
2
),
att_out
.
stride
(
2
),
k_scale
,
v_scale
,
kv_group_num
=
kv_group_num
,
kv_group_num
=
kv_group_num
,
q_head_num
=
head_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
...
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
...
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
k_scale
=
None
,
v_scale
=
None
,
):
):
_decode_att_m_fwd
(
_decode_att_m_fwd
(
q
,
q
,
...
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
...
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
k_scale
,
v_scale
,
)
)
_decode_softmax_reducev_fwd
(
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
attn_logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
...
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
...
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
k_scale
=
None
,
v_scale
=
None
,
):
):
_decode_grouped_att_m_fwd
(
_decode_grouped_att_m_fwd
(
q
,
q
,
...
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
...
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
k_scale
,
v_scale
,
)
)
_decode_softmax_reducev_fwd
(
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
attn_logits
,
q
,
o
,
lse
,
v_buffer
,
b_seq_len
,
num_kv_splits
...
@@ -671,8 +706,16 @@ def decode_attention_fwd(
...
@@ -671,8 +706,16 @@ def decode_attention_fwd(
sm_scale
,
sm_scale
,
page_size
=
1
,
page_size
=
1
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
k_scale
=
None
,
v_scale
=
None
,
):
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
if
k_scale
is
None
:
k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
if
v_scale
is
None
:
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
...
@@ -690,6 +733,8 @@ def decode_attention_fwd(
...
@@ -690,6 +733,8 @@ def decode_attention_fwd(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
k_scale
,
v_scale
,
)
)
else
:
else
:
# GQA/MQA/MLA
# GQA/MQA/MLA
...
@@ -706,4 +751,6 @@ def decode_attention_fwd(
...
@@ -706,4 +751,6 @@ def decode_attention_fwd(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
k_scale
,
v_scale
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment