Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e584dce5
Unverified
Commit
e584dce5
authored
Mar 11, 2026
by
Wuxun Zhang
Committed by
GitHub
Mar 11, 2026
Browse files
Add XPU MLA Sparse backend for DeepSeek v3.2 (#33230)
Signed-off-by:
Zhang, Wuxun
<
wuxun.zhang@intel.com
>
parent
40c0461f
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
940 additions
and
24 deletions
+940
-24
docs/design/attention_backends.md
docs/design/attention_backends.md
+1
-0
tests/kernels/attention/test_xpu_mla_sparse.py
tests/kernels/attention/test_xpu_mla_sparse.py
+118
-0
vllm/_xpu_ops.py
vllm/_xpu_ops.py
+245
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+47
-22
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+2
-1
vllm/triton_utils/__init__.py
vllm/triton_utils/__init__.py
+4
-1
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
+257
-0
vllm/v1/attention/backends/registry.py
vllm/v1/attention/backends/registry.py
+1
-0
vllm/v1/attention/ops/xpu_mla_sparse.py
vllm/v1/attention/ops/xpu_mla_sparse.py
+265
-0
No files found.
docs/design/attention_backends.md
View file @
e584dce5
...
@@ -214,3 +214,4 @@ configuration.
...
@@ -214,3 +214,4 @@ configuration.
|
`ROCM_AITER_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
`XPU_MLA_SPARSE`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
tests/kernels/attention/test_xpu_mla_sparse.py
0 → 100644
View file @
e584dce5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.v1.attention.ops.xpu_mla_sparse
import
triton_bf16_mla_sparse_interface
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/ref.py#L7
def
_merge_two_lse
(
lse0
:
torch
.
Tensor
,
lse1
:
torch
.
Tensor
|
None
,
s_q
:
int
,
h_q
:
int
)
->
torch
.
Tensor
:
if
lse1
is
None
:
return
lse0
else
:
return
torch
.
logsumexp
(
torch
.
stack
([
lse0
.
view
(
s_q
,
h_q
),
lse1
.
broadcast_to
(
s_q
,
h_q
)],
dim
=
0
),
dim
=
0
,
)
# Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/tests/ref.py#L19
def
reference_mla_sparse_prefill
(
q
:
torch
.
Tensor
,
kv
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
sm_scale
:
float
,
d_v
:
int
,
topk_length
:
torch
.
Tensor
|
None
=
None
,
attn_sink
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Returns:
- o: [s_q, h_q, dv]
- o_fp32: [s_q, h_q, dv]
- max_logits: [s_q, h_q]
- lse: [s_q, h_q]
"""
s_q
,
h_q
,
d_qk
=
q
.
shape
s_kv
,
_
,
_
=
kv
.
shape
_
,
_
,
topk
=
indices
.
shape
indices
=
indices
.
clone
().
squeeze
(
1
)
if
topk_length
is
not
None
:
mask
=
torch
.
arange
(
topk
,
device
=
topk_length
.
device
).
unsqueeze
(
0
).
broadcast_to
(
s_q
,
topk
)
>=
topk_length
.
unsqueeze
(
1
)
# [s_q, topk]
indices
[
mask
]
=
-
1
invalid_mask
=
(
indices
<
0
)
|
(
indices
>=
s_kv
)
# [s_q, topk]
indices
[
invalid_mask
]
=
0
q
=
q
.
float
()
gathered_kv
=
(
kv
.
index_select
(
dim
=
0
,
index
=
indices
.
flatten
()).
reshape
(
s_q
,
topk
,
d_qk
).
float
()
)
# [s_q, topk, d_qk]
P
=
q
@
gathered_kv
.
transpose
(
1
,
2
)
# [s_q, h_q, topk]
P
*=
sm_scale
P
[
invalid_mask
.
unsqueeze
(
1
).
broadcast_to
(
P
.
shape
)]
=
float
(
"-inf"
)
orig_lse
=
torch
.
logsumexp
(
P
,
dim
=-
1
)
# [s_q, h_q]
max_logits
=
P
.
max
(
dim
=-
1
).
values
# [s_q, h_q]
lse_for_o
=
_merge_two_lse
(
orig_lse
,
attn_sink
,
s_q
,
h_q
)
if
not
torch
.
is_inference_mode_enabled
():
lse_for_o
=
lse_for_o
.
clone
()
lse_for_o
[
lse_for_o
==
float
(
"-inf"
)]
=
float
(
"+inf"
)
# So that corresponding O will be 0
s_for_o
=
torch
.
exp
(
P
-
lse_for_o
.
unsqueeze
(
-
1
))
out
=
s_for_o
@
gathered_kv
[...,
:
d_v
]
# [s_q, h_q, dv]
lonely_q_mask
=
orig_lse
==
float
(
"-inf"
)
# [s_q, h_q]
orig_lse
[
lonely_q_mask
]
=
float
(
"+inf"
)
return
(
out
.
to
(
kv
.
dtype
),
out
,
max_logits
,
orig_lse
)
@
pytest
.
mark
.
parametrize
(
"device_str"
,
[
"xpu"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
skipif
(
not
torch
.
xpu
.
is_available
(),
reason
=
"XPU is required"
,
)
def
test_bf16_triton_sparse_mla
(
device_str
,
dtype
):
device
=
torch
.
device
(
device_str
)
s_q
=
1
s_kv
=
256
h_q
=
64
# kernel expects multiple of 64
h_kv
=
1
d_qk
=
576
d_v
=
512
topk
=
128
torch
.
random
.
manual_seed
(
1234
)
q
=
torch
.
randn
((
s_q
,
h_q
,
d_qk
),
dtype
=
dtype
,
device
=
device
)
kv
=
torch
.
randn
((
s_kv
,
h_kv
,
d_qk
),
dtype
=
dtype
,
device
=
device
)
indices
=
torch
.
full
((
s_q
,
h_kv
,
topk
),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
for
t
in
range
(
s_q
):
for
h
in
range
(
h_kv
):
i_i
=
torch
.
randperm
(
max
(
1
,
t
))[:
topk
]
indices
[
t
,
h
,
:
len
(
i_i
)]
=
i_i
sm_scale
=
d_qk
**-
0.5
out
,
max_logits
,
lse
=
triton_bf16_mla_sparse_interface
(
q
,
kv
,
indices
,
sm_scale
,
d_v
)
assert
out
.
shape
==
(
s_q
,
h_q
,
d_v
)
assert
max_logits
.
shape
==
(
s_q
,
h_q
)
assert
lse
.
shape
==
(
s_q
,
h_q
)
ref_out
,
ref_out_fp32
,
ref_max_logits
,
ref_lse
=
reference_mla_sparse_prefill
(
q
,
kv
,
indices
,
sm_scale
,
d_v
)
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
max_logits
,
ref_max_logits
,
atol
=
1e-3
,
rtol
=
1e-3
)
assert
torch
.
allclose
(
lse
,
ref_lse
,
atol
=
1e-3
,
rtol
=
1e-3
)
vllm/_xpu_ops.py
View file @
e584dce5
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
from
vllm_xpu_kernels.flash_attn_interface
import
flash_attn_varlen_func
from
vllm_xpu_kernels.flash_attn_interface
import
flash_attn_varlen_func
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -157,3 +158,247 @@ class xpu_ops:
...
@@ -157,3 +158,247 @@ class xpu_ops:
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
)
)
return
None
return
None
@
staticmethod
def
indexer_k_quant_and_cache
(
k
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
)
->
None
:
head_dim
=
k
.
shape
[
-
1
]
k
=
k
.
view
(
-
1
,
head_dim
)
# [total_tokens, head_dim]
def
group_quant_torch
(
x
:
torch
.
Tensor
,
group_size
:
int
,
eps
:
float
=
1e-10
,
dtype
:
torch
.
dtype
|
None
=
None
,
column_major_scales
:
bool
=
False
,
out_q
:
torch
.
Tensor
|
None
=
None
,
use_ue8m0
:
bool
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
use_ue8m0
is
None
:
# Default fallback - could import is_deep_gemm_e8m0_used if needed
use_ue8m0
=
False
if
dtype
is
None
:
dtype
=
current_platform
.
fp8_dtype
()
# Validate inputs
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
(
f
"Last dimension
{
x
.
shape
[
-
1
]
}
must be divisible by "
f
"group_size
{
group_size
}
"
)
assert
x
.
stride
(
-
1
)
==
1
,
"Input tensor groups must be contiguous"
# Prepare output tensor
if
out_q
is
None
:
x_q
=
torch
.
empty_like
(
x
,
dtype
=
dtype
)
else
:
assert
out_q
.
shape
==
x
.
shape
x_q
=
out_q
# Reshape input for group processing
# Original shape: (..., last_dim)
# Target shape: (..., num_groups, group_size)
original_shape
=
x
.
shape
num_groups
=
original_shape
[
-
1
]
//
group_size
# Reshape to separate groups
group_shape
=
original_shape
[:
-
1
]
+
(
num_groups
,
group_size
)
x_grouped
=
x
.
view
(
group_shape
)
# Compute per-group absolute maximum values
# Shape: (..., num_groups)
abs_max
=
torch
.
amax
(
torch
.
abs
(
x_grouped
),
dim
=-
1
,
keepdim
=
False
)
abs_max
=
torch
.
maximum
(
abs_max
,
torch
.
tensor
(
eps
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
)
# Compute scales
FP8_MAX
=
torch
.
finfo
(
dtype
).
max
FP8_MIN
=
torch
.
finfo
(
dtype
).
min
scale_raw
=
abs_max
/
FP8_MAX
if
use_ue8m0
:
# For UE8M0 format, scales must be powers of 2
scales
=
torch
.
pow
(
2.0
,
torch
.
ceil
(
torch
.
log2
(
scale_raw
)))
else
:
scales
=
scale_raw
# Expand scales for broadcasting with grouped data
# Shape: (..., num_groups, 1)
scales_expanded
=
scales
.
unsqueeze
(
-
1
)
# Quantize the grouped data
x_scaled
=
x_grouped
/
scales_expanded
x_clamped
=
torch
.
clamp
(
x_scaled
,
FP8_MIN
,
FP8_MAX
)
x_quantized
=
x_clamped
.
to
(
dtype
)
# Reshape back to original shape
x_q
.
copy_
(
x_quantized
.
view
(
original_shape
))
# Prepare scales tensor in requested format
if
column_major_scales
:
# Column-major: (num_groups,) + batch_dims
# Transpose the scales to put group dimension first
scales_shape
=
(
num_groups
,)
+
original_shape
[:
-
1
]
x_s
=
scales
.
permute
(
-
1
,
*
range
(
len
(
original_shape
)
-
1
))
x_s
=
x_s
.
contiguous
().
view
(
scales_shape
)
else
:
# Row-major: batch_dims + (num_groups,)
x_s
=
scales
.
contiguous
()
# Ensure scales are float32
return
x_q
,
x_s
.
float
()
k_fp8
,
k_scale
=
group_quant_torch
(
k
,
group_size
=
quant_block_size
,
column_major_scales
=
False
,
use_ue8m0
=
(
scale_fmt
==
"ue8m0"
),
)
k_fp8_bytes
=
k_fp8
.
view
(
-
1
,
head_dim
).
view
(
torch
.
uint8
)
scale_bytes
=
k_scale
.
view
(
torch
.
uint8
).
view
(
-
1
,
4
)
k
=
torch
.
cat
(
[
k_fp8_bytes
,
scale_bytes
],
dim
=-
1
)
# [total_tokens, head_dim + 4]
slot_mapping
=
slot_mapping
.
flatten
()
# kv_cache: [num_block, block_size, head_dim + 4]
kv_cache
.
view
(
-
1
,
kv_cache
.
shape
[
-
1
]).
index_copy_
(
0
,
slot_mapping
,
k
)
@
staticmethod
def
cp_gather_indexer_k_quant_cache
(
kv_cache
:
torch
.
Tensor
,
dst_k
:
torch
.
Tensor
,
dst_scale
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
)
->
None
:
"""
Args:
kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache
Layout per block: [k_values, scale_values]
- k_values: [block_size * head_dim]
- scale_values: [block_size * head_dim * 4 / quant_block_size]
dst_k: [num_tokens, head_dim] - output tensor for K values
dst_scale: [num_tokens, head_dim / quant_block_size * 4]
- output tensor for scale values
block_table: [batch_size, num_blocks] - block table for indexing
cu_seq_lens: [batch_size + 1] - cumulative sequence lengths
"""
batch_size
=
block_table
.
size
(
0
)
num_tokens
=
dst_k
.
size
(
0
)
head_dim
=
dst_k
.
size
(
1
)
cache_block_size
=
kv_cache
.
size
(
1
)
quant_block_size
=
head_dim
*
4
//
dst_scale
.
size
(
1
)
# For each token, find which batch it belongs to using searchsorted
token_indices
=
torch
.
arange
(
num_tokens
,
device
=
dst_k
.
device
)
+
1
# cu_seq_lens is [batch_size + 1], we need to find which interval each
# token belongs to
batch_indices
=
torch
.
searchsorted
(
cu_seq_lens
,
token_indices
)
-
1
batch_indices
=
torch
.
clamp
(
batch_indices
,
0
,
batch_size
-
1
)
# Calculate the in-batch sequence index for each token
inbatch_seq_indices
=
token_indices
-
cu_seq_lens
[
batch_indices
]
# Find which block each token belongs to
block_indices_in_table
=
inbatch_seq_indices
//
cache_block_size
physical_block_indices
=
block_table
[
batch_indices
,
block_indices_in_table
]
# Calculate the offset within each block
inblock_offsets
=
(
inbatch_seq_indices
-
1
)
%
cache_block_size
# Calculate strides
block_stride
=
kv_cache
.
stride
(
0
)
# stride for each block
# Flatten kv_cache for easier indexing
kv_cache_flat
=
kv_cache
.
view
(
-
1
)
# Calculate source offset for K values for all tokens (vectorized)
src_block_offsets
=
physical_block_indices
*
block_stride
src_k_offsets
=
src_block_offsets
+
inblock_offsets
*
head_dim
# Gather K values using advanced indexing
# Create indices for all elements we need to gather
k_indices
=
src_k_offsets
.
unsqueeze
(
1
)
+
torch
.
arange
(
head_dim
,
device
=
dst_k
.
device
)
dst_k
[:]
=
kv_cache_flat
[
k_indices
]
# Calculate source offset for scale values (vectorized)
# Scales are stored after all K values for each block
scale_size
=
head_dim
*
4
//
quant_block_size
src_scale_offsets
=
src_block_offsets
+
head_dim
+
inblock_offsets
*
scale_size
# Gather scale values
scale_indices
=
src_scale_offsets
.
unsqueeze
(
1
)
+
torch
.
arange
(
scale_size
,
device
=
dst_scale
.
device
)
dst_scale
[:]
=
kv_cache_flat
[
scale_indices
]
@
staticmethod
def
top_k_per_row_prefill
(
logits
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
raw_topk_indices
:
torch
.
Tensor
,
num_rows
:
int
,
stride0
:
int
,
strdide1
:
int
,
topk_tokens
:
int
,
)
->
torch
.
Tensor
:
real_topk
=
min
(
topk_tokens
,
logits
.
shape
[
-
1
])
topk_indices
=
logits
.
topk
(
real_topk
,
dim
=-
1
)[
1
].
to
(
torch
.
int32
)
topk_indices
-=
cu_seqlen_ks
[:,
None
]
mask_lo
=
topk_indices
>=
0
mask_hi
=
topk_indices
-
(
cu_seqlen_ke
-
cu_seqlen_ks
)[:,
None
]
<
0
mask
=
torch
.
full_like
(
topk_indices
,
False
,
dtype
=
torch
.
bool
,
device
=
topk_indices
.
device
)
mask
=
mask_lo
&
mask_hi
topk_indices
.
masked_fill_
(
~
mask
,
-
1
)
raw_topk_indices
[:
topk_indices
.
shape
[
0
],
:
topk_indices
.
shape
[
1
]]
=
(
topk_indices
)
@
staticmethod
def
top_k_per_row_decode
(
logits
:
torch
.
Tensor
,
next_n
:
int
,
seq_lens
:
torch
.
Tensor
,
raw_topk_indices
:
torch
.
Tensor
,
num_rows
:
int
,
stride0
:
int
,
stride1
:
int
,
topk_tokens
:
int
,
)
->
torch
.
Tensor
:
device
=
logits
.
device
batch_size
=
seq_lens
.
size
(
0
)
# padded query len
padded_num_tokens
=
batch_size
*
next_n
positions
=
(
torch
.
arange
(
logits
.
shape
[
-
1
],
device
=
device
)
.
unsqueeze
(
0
)
.
expand
(
batch_size
*
next_n
,
-
1
)
)
row_indices
=
torch
.
arange
(
padded_num_tokens
,
device
=
device
)
//
next_n
next_n_offset
=
torch
.
arange
(
padded_num_tokens
,
device
=
device
)
%
next_n
index_end_pos
=
(
seq_lens
[
row_indices
]
-
next_n
+
next_n_offset
).
unsqueeze
(
1
)
# index_end_pos: [B * N, 1]
mask
=
positions
<=
index_end_pos
# mask: [B * N, L]
logits
=
logits
.
masked_fill
(
~
mask
,
float
(
"-inf"
))
topk_indices
=
logits
.
topk
(
topk_tokens
,
dim
=-
1
)[
1
].
to
(
torch
.
int32
)
# [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices
[
topk_indices
>
index_end_pos
]
=
-
1
raw_topk_indices
[:
topk_indices
.
shape
[
0
],
:
topk_indices
.
shape
[
1
]]
=
(
topk_indices
)
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
e584dce5
...
@@ -135,16 +135,29 @@ def sparse_attn_indexer(
...
@@ -135,16 +135,29 @@ def sparse_attn_indexer(
topk_indices
=
topk_indices_buffer
[
topk_indices
=
topk_indices_buffer
[
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
]
]
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
if
current_platform
.
is_xpu
():
chunk
.
cu_seqlen_ks
,
ops
.
top_k_per_row_prefill
(
chunk
.
cu_seqlen_ke
,
logits
,
topk_indices
,
chunk
.
cu_seqlen_ks
,
num_rows
,
chunk
.
cu_seqlen_ke
,
logits
.
stride
(
0
),
topk_indices
,
logits
.
stride
(
1
),
num_rows
,
topk_tokens
,
logits
.
stride
(
0
),
)
logits
.
stride
(
1
),
topk_tokens
,
)
else
:
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
# Compute lengths from row spans
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
...
@@ -220,16 +233,28 @@ def sparse_attn_indexer(
...
@@ -220,16 +233,28 @@ def sparse_attn_indexer(
None
,
None
,
)
)
else
:
else
:
torch
.
ops
.
_C
.
top_k_per_row_decode
(
if
current_platform
.
is_xpu
():
logits
,
ops
.
top_k_per_row_decode
(
next_n
,
logits
,
decode_metadata
.
seq_lens
,
next_n
,
topk_indices
,
decode_metadata
.
seq_lens
,
num_rows
,
topk_indices
,
logits
.
stride
(
0
),
num_rows
,
logits
.
stride
(
1
),
logits
.
stride
(
0
),
topk_tokens
,
logits
.
stride
(
1
),
)
topk_tokens
,
)
else
:
torch
.
ops
.
_C
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
decode_metadata
.
requires_padding
:
if
decode_metadata
.
requires_padding
:
# if padded, we need to unpack
# if padded, we need to unpack
...
@@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp):
...
@@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp):
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
):
):
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
()
or
current_platform
.
is_xpu
()
:
return
self
.
forward_cuda
(
hidden_states
,
q_fp8
,
k
,
weights
)
return
self
.
forward_cuda
(
hidden_states
,
q_fp8
,
k
,
weights
)
elif
current_platform
.
is_rocm
():
elif
current_platform
.
is_rocm
():
return
self
.
forward_hip
(
hidden_states
,
q_fp8
,
k
,
weights
)
return
self
.
forward_hip
(
hidden_states
,
q_fp8
,
k
,
weights
)
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"SparseAttnIndexer native forward is only implemented for "
"SparseAttnIndexer native forward is only implemented for "
"CUDA
and
ROCm platform."
"CUDA
,
ROCm
and XPU
platform
s
."
)
)
def
forward_cuda
(
def
forward_cuda
(
...
...
vllm/platforms/xpu.py
View file @
e584dce5
...
@@ -61,7 +61,8 @@ class XPUPlatform(Platform):
...
@@ -61,7 +61,8 @@ class XPUPlatform(Platform):
dtype
=
attn_selector_config
.
dtype
dtype
=
attn_selector_config
.
dtype
if
attn_selector_config
.
use_sparse
:
if
attn_selector_config
.
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on XPU."
)
logger
.
info_once
(
"Using XPU MLA Sparse backend."
)
return
AttentionBackendEnum
.
XPU_MLA_SPARSE
.
get_path
()
if
attn_selector_config
.
use_mla
:
if
attn_selector_config
.
use_mla
:
logger
.
info_once
(
"Using Triton MLA backend on V1 engine."
)
logger
.
info_once
(
"Using Triton MLA backend on V1 engine."
)
return
AttentionBackendEnum
.
TRITON_MLA
.
get_path
()
return
AttentionBackendEnum
.
TRITON_MLA
.
get_path
()
...
...
vllm/triton_utils/__init__.py
View file @
e584dce5
...
@@ -17,4 +17,7 @@ else:
...
@@ -17,4 +17,7 @@ else:
tl
=
TritonLanguagePlaceholder
()
tl
=
TritonLanguagePlaceholder
()
tldevice
=
TritonLanguagePlaceholder
()
tldevice
=
TritonLanguagePlaceholder
()
__all__
=
[
"HAS_TRITON"
,
"triton"
,
"tl"
,
"tldevice"
]
LOG2E
=
1.4426950408889634
LOGE2
=
0.6931471805599453
__all__
=
[
"HAS_TRITON"
,
"triton"
,
"tl"
,
"tldevice"
,
"LOG2E"
,
"LOGE2"
]
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
0 → 100644
View file @
e584dce5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Optional
import
numpy
as
np
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
(
get_mla_dims
,
)
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
SparseMLAAttentionImpl
,
)
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
triton_convert_req_index_to_global_index
,
)
from
vllm.v1.attention.ops.xpu_mla_sparse
import
triton_bf16_mla_sparse_interface
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
TYPE_CHECKING
:
from
vllm.model_executor.models.deepseek_v2
import
Indexer
logger
=
init_logger
(
__name__
)
class
XPUMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
]
@
staticmethod
def
get_name
()
->
str
:
return
"XPU_MLA_SPARSE"
@
staticmethod
def
get_metadata_cls
()
->
type
[
"XPUMLASparseMetadata"
]:
return
XPUMLASparseMetadata
@
staticmethod
def
get_builder_cls
()
->
type
[
"XPUMLASparseMetadataBuilder"
]:
return
XPUMLASparseMetadataBuilder
@
staticmethod
def
get_impl_cls
()
->
type
[
"XPUMLASparseImpl"
]:
return
XPUMLASparseImpl
@
classmethod
def
is_mla
(
cls
)
->
bool
:
return
True
@
classmethod
def
is_sparse
(
cls
)
->
bool
:
return
True
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
@
dataclass
class
XPUMLASparseMetadata
(
AttentionMetadata
):
num_reqs
:
int
max_query_len
:
int
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
block_table
:
torch
.
Tensor
req_id_per_token
:
torch
.
Tensor
block_size
:
int
=
1
topk_tokens
:
int
=
2048
@
dataclass
class
XPUMLASparseMetadataBuilder
(
AttentionMetadataBuilder
[
XPUMLASparseMetadata
]):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
NEVER
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
self
.
kv_cache_spec
=
kv_cache_spec
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
self
.
device
=
device
max_num_batched_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
topk_tokens
=
vllm_config
.
model_config
.
hf_config
.
index_topk
self
.
topk_tokens_tensor
=
torch
.
tensor
(
[
self
.
topk_tokens
],
device
=
device
,
dtype
=
torch
.
int32
)
self
.
max_model_len_tensor
=
torch
.
tensor
(
[
self
.
model_config
.
max_model_len
],
device
=
device
,
dtype
=
torch
.
int32
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self
.
dummy_block_table
=
torch
.
empty
(
(
1
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
req_id_per_token_buffer
=
torch
.
empty
(
(
max_num_batched_tokens
,),
dtype
=
torch
.
int32
,
device
=
device
,
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
XPUMLASparseMetadata
:
num_tokens
=
common_attn_metadata
.
num_actual_tokens
starts
=
np
.
asarray
(
common_attn_metadata
.
query_start_loc_cpu
,
dtype
=
np
.
int32
)
seg_lengths
=
np
.
diff
(
starts
)
req_id_per_token
=
np
.
repeat
(
np
.
arange
(
seg_lengths
.
shape
[
0
],
dtype
=
np
.
int32
),
seg_lengths
)
# Zero-fill for cudagraphs
self
.
req_id_per_token_buffer
.
fill_
(
0
)
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
)
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
metadata
=
XPUMLASparseMetadata
(
num_reqs
=
common_attn_metadata
.
num_reqs
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
req_id_per_token
=
req_id_per_token
,
block_size
=
self
.
kv_cache_spec
.
block_size
,
topk_tokens
=
self
.
topk_tokens
,
)
return
metadata
class
XPUMLASparseImpl
(
SparseMLAAttentionImpl
[
XPUMLASparseMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
list
[
float
]
|
None
,
sliding_window
:
int
|
None
,
kv_cache_dtype
:
str
,
logits_soft_cap
:
float
|
None
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
str
|
None
,
# MLA Specific Arguments
topk_indice_buffer
:
torch
.
Tensor
|
None
=
None
,
indexer
:
Optional
[
"Indexer"
]
=
None
,
**
mla_args
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_lora_rank
:
int
=
mla_args
[
"kv_lora_rank"
]
self
.
softmax_scale
=
scale
assert
indexer
is
not
None
self
.
topk_indices_buffer
:
torch
.
Tensor
|
None
=
indexer
.
topk_indices_buffer
def
_forward_bf16_kv
(
self
,
q
:
torch
.
Tensor
,
# [sq, heads, d_qk]
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
# [blocks, heads, d_qk]
topk_indices
:
torch
.
Tensor
,
# [sq, topk]
attn_metadata
:
XPUMLASparseMetadata
,
)
->
torch
.
Tensor
:
num_tokens
=
q
.
shape
[
0
]
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
-
1
,
1
,
kv_c_and_k_pe_cache
.
shape
[
-
1
]
)
topk_indices
=
topk_indices
.
view
(
num_tokens
,
1
,
-
1
)
output
,
_
,
_
=
triton_bf16_mla_sparse_interface
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
sm_scale
=
self
.
softmax_scale
,
)
return
output
[:,
:
self
.
num_heads
,
:]
def
forward_mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
XPUMLASparseMetadata
,
layer
:
AttentionLayer
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 kv is not supported with XPU MLA Sparse yet"
)
# Concatenate q if it's a tuple (ql_nope, q_pe)
if
isinstance
(
q
,
tuple
):
q
=
torch
.
cat
(
q
,
dim
=-
1
)
num_actual_toks
=
q
.
shape
[
0
]
assert
self
.
topk_indices_buffer
is
not
None
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
topk_indices_global
=
triton_convert_req_index_to_global_index
(
attn_metadata
.
req_id_per_token
,
attn_metadata
.
block_table
,
topk_indices
,
BLOCK_SIZE
=
attn_metadata
.
block_size
,
NUM_TOPK_TOKENS
=
attn_metadata
.
topk_tokens
,
)
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_c_and_k_pe_cache
,
topk_indices_global
,
attn_metadata
)
return
attn_out
,
None
vllm/v1/attention/backends/registry.py
View file @
e584dce5
...
@@ -57,6 +57,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
...
@@ -57,6 +57,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_MLA_SPARSE
=
(
ROCM_AITER_MLA_SPARSE
=
(
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
)
)
XPU_MLA_SPARSE
=
"vllm.v1.attention.backends.mla.xpu_mla_sparse.XPUMLASparseBackend"
TORCH_SDPA
=
""
# this tag is only used for ViT
TORCH_SDPA
=
""
# this tag is only used for ViT
FLASHINFER
=
"vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER
=
"vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA
=
(
FLASHINFER_MLA
=
(
...
...
vllm/v1/attention/ops/xpu_mla_sparse.py
0 → 100644
View file @
e584dce5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
LOG2E
,
LOGE2
,
tl
,
triton
@
triton
.
jit
def
_bf16_mla_sparse_kernel
(
q_buffer
,
k_buffer
,
v_buffer
,
indices_ptr
,
out_ptr
,
softmax_lse_ptr
,
max_logits_ptr
,
seq_q
,
seq_kv
,
h_q
,
dim_qk
,
dim_v
,
stride_q_token
,
stride_q_head
,
stride_k_token
,
stride_k_head
,
stride_v_token
,
stride_v_head
,
stride_out_token
,
stride_out_head
,
stride_lse
,
stride_indices_token
,
stride_indices_head
,
sm_scale
,
kv_group_num
:
tl
.
constexpr
,
index_topk
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
# block size for num heads
BLOCK_M
:
tl
.
constexpr
,
# block size for num tokens
BLOCK_N
:
tl
.
constexpr
,
# block size for indices
BLOCK_DV
:
tl
.
constexpr
,
# block size for dim_v
BLOCK_DMODEL
:
tl
.
constexpr
,
# block size for dim_nope
BLOCK_DPE
:
tl
.
constexpr
,
# block size for positional embedding
LOGE2
:
tl
.
constexpr
,
):
cur_q
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head_id
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
if
kv_group_num
>
BLOCK_H
else
kv_group_num
cur_head
=
cur_head_id
*
VALID_BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
mask_h
=
cur_head
<
(
cur_head_id
+
1
)
*
VALID_BLOCK_H
mask_h
=
mask_h
&
(
cur_head
<
h_q
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
off_q
=
cur_q
*
stride_q_token
+
cur_head
[:,
None
]
*
stride_q_head
+
offs_d
[
None
,
:]
mask_dmodel
=
offs_d
<
BLOCK_DMODEL
q
=
tl
.
load
(
q_buffer
+
off_q
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dmodel
[
None
,
:]),
other
=
0.0
)
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
off_qpe
=
(
cur_q
*
stride_q_token
+
cur_head
[:,
None
]
*
stride_q_head
+
offs_dpe
[
None
,
:]
)
# assume dim_qk == BLOCK_DMODEL + BLOCK_DPE
mask_dpe
=
offs_dpe
<
dim_qk
qpe
=
tl
.
load
(
q_buffer
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
)
e_max
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
e_sum
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
for
start_indice
in
range
(
0
,
index_topk
,
BLOCK_N
):
offs_indice
=
start_indice
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_indice
=
offs_indice
<
index_topk
indices
=
tl
.
load
(
indices_ptr
+
(
cur_q
*
stride_indices_token
+
cur_kv_head_id
*
stride_indices_head
+
offs_indice
),
mask
=
mask_indice
,
other
=-
1
,
)
mask_kv
=
(
indices
>=
0
)
&
(
indices
<
seq_kv
)
mask_kv_d
=
mask_dmodel
offs_k
=
(
indices
[
None
,
:]
*
stride_k_token
+
cur_kv_head_id
*
stride_k_head
+
offs_d
[:,
None
]
)
# q_nope @ k_nope
k
=
tl
.
load
(
k_buffer
+
offs_k
,
mask
=
(
mask_kv
[
None
,
:])
&
(
mask_kv_d
[:,
None
]),
other
=
0.0
)
qk
=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
if
BLOCK_DPE
>
0
:
# q_rope @ k_rope
offs_kpe
=
(
indices
[
None
,
:]
*
stride_k_token
+
cur_kv_head_id
*
stride_k_head
+
offs_dpe
[:,
None
]
)
mask_k_dpe
=
offs_dpe
<
dim_qk
kpe
=
tl
.
load
(
k_buffer
+
offs_kpe
,
mask
=
(
mask_kv
[
None
,
:])
&
(
mask_k_dpe
[:,
None
]),
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
,
kpe
.
to
(
q
.
dtype
))
# apply scaling
qk
*=
sm_scale
qk
=
tl
.
where
((
mask_h
[:,
None
])
&
(
mask_kv
[
None
,
:]),
qk
,
-
float
(
"inf"
))
# load v
mask_v_d
=
offs_dv
<
dim_v
offs_v
=
(
indices
[:,
None
]
*
stride_v_token
+
cur_kv_head_id
*
stride_v_head
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
v_buffer
+
offs_v
,
mask
=
(
mask_kv
[:,
None
])
&
(
mask_v_d
[
None
,
:]),
other
=
0.0
)
# online softmax
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp2
(
e_max
-
n_e_max
)
p
=
tl
.
exp2
(
qk
-
n_e_max
[:,
None
])
acc
*=
re_scale
[:,
None
]
# score @ v
acc
+=
tl
.
dot
(
p
.
to
(
v
.
dtype
),
v
)
# update global sum and max
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
# rescaling
acc
/=
e_sum
[:,
None
]
max_logits
=
e_max
*
LOGE2
# calculate lse
lse
=
max_logits
+
tl
.
log2
(
e_sum
)
*
LOGE2
# write output
offs_o
=
(
cur_q
*
stride_out_token
+
cur_head
[:,
None
]
*
stride_out_head
+
offs_dv
[
None
,
:]
)
mask_out_d
=
offs_dv
<
dim_v
tl
.
store
(
out_ptr
+
offs_o
,
acc
.
to
(
tl
.
bfloat16
),
mask
=
(
mask_h
[:,
None
])
&
(
mask_out_d
[
None
,
:]),
)
offs_lse
=
cur_q
*
stride_lse
+
cur_head
tl
.
store
(
softmax_lse_ptr
+
offs_lse
,
lse
,
mask
=
mask_h
)
tl
.
store
(
max_logits_ptr
+
offs_lse
,
max_logits
,
mask
=
mask_h
)
# reference implementation of bf16 sparse prefill kernel
def
triton_bf16_mla_sparse_interface
(
q
:
torch
.
Tensor
,
# [num_tokens, num_heads_q, dim_qk]
kv
:
torch
.
Tensor
,
# [num_tokens, num_heads_kv, dim_qk]
indices
:
torch
.
Tensor
,
# [num_tokens, num_heads_kv, topk]
sm_scale
:
float
,
d_v
:
int
=
512
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
out : [num_tokens, num_heads_q, d_v]
max_logits : [num_tokens, num_heads_q]
lse : logsumexp, [num_tokens, num_heads_q]
"""
num_tokens
,
num_heads_q
,
dim_qk
=
q
.
shape
_
,
num_heads_kv
,
_
=
kv
.
shape
assert
dim_qk
==
kv
.
shape
[
2
],
"q and kv have different head dimensions"
# for deepseek v3.2, index topk should be 2048
_
,
_
,
index_topk
=
indices
.
shape
BLOCK_H
=
16
BLOCK_DMODEL
=
512
BLOCK_DPE
=
64
BLOCK_M
=
32
BLOCK_N
=
16
BLOCK_DV
=
512
assert
d_v
==
BLOCK_DV
,
"only support d_v = 512"
assert
dim_qk
==
BLOCK_DMODEL
+
BLOCK_DPE
,
(
"dim_qk does not match BLOCK_DMODEL + BLOCK_DPE"
)
assert
num_heads_kv
==
1
,
"only support kv head = 1 for now"
assert
index_topk
%
BLOCK_N
==
0
,
"index_topk must be multiple of BLOCK_N"
sm_scale
*=
LOG2E
kv_group_num
=
num_heads_q
//
num_heads_kv
grid
=
(
num_tokens
,
triton
.
cdiv
(
num_heads_q
,
min
(
BLOCK_H
,
kv_group_num
)),
)
out
=
torch
.
zeros
((
num_tokens
,
num_heads_q
,
d_v
),
dtype
=
q
.
dtype
,
device
=
q
.
device
)
softmax_lse
=
torch
.
zeros
(
(
num_tokens
,
num_heads_q
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
max_logits
=
torch
.
zeros
(
(
num_tokens
,
num_heads_q
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
kv
v
=
kv
[...,
:
d_v
]
_bf16_mla_sparse_kernel
[
grid
](
q_buffer
=
q
,
k_buffer
=
k
,
v_buffer
=
v
,
indices_ptr
=
indices
,
out_ptr
=
out
,
softmax_lse_ptr
=
softmax_lse
,
max_logits_ptr
=
max_logits
,
seq_q
=
num_tokens
,
seq_kv
=
kv
.
shape
[
0
],
h_q
=
num_heads_q
,
dim_qk
=
dim_qk
,
dim_v
=
d_v
,
stride_q_token
=
q
.
stride
(
0
),
stride_q_head
=
q
.
stride
(
1
),
stride_k_token
=
k
.
stride
(
0
),
stride_k_head
=
k
.
stride
(
1
),
stride_v_token
=
v
.
stride
(
0
),
stride_v_head
=
v
.
stride
(
1
),
stride_out_token
=
out
.
stride
(
0
),
stride_out_head
=
out
.
stride
(
1
),
stride_lse
=
softmax_lse
.
stride
(
0
),
stride_indices_token
=
indices
.
stride
(
0
),
stride_indices_head
=
indices
.
stride
(
1
),
sm_scale
=
sm_scale
,
kv_group_num
=
kv_group_num
,
index_topk
=
index_topk
,
BLOCK_H
=
BLOCK_H
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
LOGE2
=
LOGE2
,
)
return
out
,
max_logits
,
softmax_lse
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment