Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d71496bf
Commit
d71496bf
authored
Mar 03, 2026
by
zhuwenwen
Browse files
support dsa
parent
1ce0a9a2
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
102 additions
and
58 deletions
+102
-58
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
+1
-1
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+1
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+67
-38
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+2
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+9
-8
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+10
-9
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-1
No files found.
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
View file @
d71496bf
...
...
@@ -755,7 +755,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3"
|| KV_DTYPE == "fp8_ds_mla"
) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
d71496bf
...
...
@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
getattr
(
layer
,
"_marlin_w16a16_moe_enabled"
,
False
)
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
d71496bf
...
...
@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
lightop
import
op
,
gemmopt
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
...
...
@@ -73,6 +74,7 @@ def sparse_attn_indexer(
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
...
...
@@ -86,6 +88,7 @@ def sparse_attn_indexer(
prefill_metadata
=
attn_metadata
.
prefill
# Get the full shared workspace buffers once (will allocate on first use)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
workspace_manager
=
current_workspace_manager
()
k_fp8_full
,
k_scale_full
=
workspace_manager
.
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
),
...
...
@@ -109,6 +112,19 @@ def sparse_attn_indexer(
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
else
:
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k
.
shape
[
0
],
64
,
128
,
True
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
...
...
@@ -149,6 +165,7 @@ def sparse_attn_indexer(
assert
batch_size
==
decode_metadata
.
seq_lens
.
shape
[
0
]
num_padded_tokens
=
batch_size
*
next_n
if
not
current_platform
.
is_rocm
():
logits
=
fp8_paged_mqa_logits
(
padded_q_fp8_decode_tokens
,
kv_cache
,
...
...
@@ -158,6 +175,17 @@ def sparse_attn_indexer(
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
)
else
:
logits
=
gemmopt
.
paged_mqa_logits
(
padded_q_fp8_decode_tokens
,
kv_cache
,
weights
[:
num_padded_tokens
]
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
else
weights
[:
num_padded_tokens
].
to
(
torch
.
float32
),
decode_metadata
.
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
max_model_len
,
)
num_rows
=
logits
.
shape
[
0
]
...
...
@@ -258,7 +286,8 @@ class SparseAttnIndexer(CustomOp):
if
current_platform
.
is_cuda
():
return
self
.
forward_cuda
(
hidden_states
,
q_fp8
,
k
,
weights
)
elif
current_platform
.
is_rocm
():
return
self
.
forward_hip
(
hidden_states
,
q_fp8
,
k
,
weights
)
# return self.forward_hip(hidden_states, q_fp8, k, weights)
return
self
.
forward_cuda
(
hidden_states
,
q_fp8
,
k
,
weights
)
else
:
raise
NotImplementedError
(
"SparseAttnIndexer native forward is only implemented for "
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
d71496bf
...
...
@@ -712,6 +712,8 @@ class Indexer(nn.Module):
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
else
:
q_fp8
=
q
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
...
...
vllm/platforms/rocm.py
View file @
d71496bf
...
...
@@ -261,15 +261,16 @@ class RocmPlatform(Platform):
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
attn_selector_config
.
use_sparse
:
if
kv_cache_dtype
and
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
ValueError
(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
assert
block_size
==
1
,
(
"Sparse MLA backend on ROCm only supports block size 1 for now."
)
#
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
#
raise ValueError(
#
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
#
)
#
assert block_size == 1, (
#
"Sparse MLA backend on ROCm only supports block size 1 for now."
#
)
logger
.
info_once
(
"Using Sparse MLA backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_MLA_SPARSE
.
get_path
()
# return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
return
AttentionBackendEnum
.
FLASHMLA_SPARSE
.
get_path
()
if
attn_selector_config
.
use_mla
:
# if attn_selector_config.use_sparse:
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
d71496bf
...
...
@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class
DeepseekV32IndexerBackend
(
AttentionBackend
):
exclude_from_block_size_selection
=
True
@
staticmethod
def
get_name
()
->
str
:
return
"DEEPSEEK_V32_INDEXER"
...
...
@@ -323,7 +324,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding
=
(
decode_lens_cpu
.
max
()
>
decode_lens_cpu
.
min
()).
item
()
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
]
if
is_deep_gemm_supported
():
#
if is_deep_gemm_supported():
if
current_platform
.
is_rocm
():
self
.
scheduler_metadata_buffer
=
gemmopt
.
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
...
...
vllm/v1/attention/ops/flashmla.py
View file @
d71496bf
...
...
@@ -31,7 +31,8 @@ else:
if
current_platform
.
is_rocm
():
import
flash_mla.cuda
as
flash_mla_cuda
# import flash_mla.cuda as flash_mla_cuda
from
flash_mla.flash_mla_interface
import
flash_mla_cuda
_flashmla_C_AVAILABLE
=
True
_flashmla_extension_C_AVAILABLE
=
True
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d71496bf
...
...
@@ -5537,6 +5537,10 @@ class GPUModelRunner(
ValueError: If no valid block size found
"""
#exclude indexer backend
def
_participates_in_block_size_selection
(
backend
:
type
[
AttentionBackend
])
->
bool
:
return
not
getattr
(
backend
,
"exclude_from_block_size_selection"
,
False
)
def
block_size_is_supported
(
backends
:
list
[
type
[
AttentionBackend
]],
block_size
:
int
)
->
bool
:
...
...
@@ -5558,7 +5562,12 @@ class GPUModelRunner(
return
False
return
True
backends
=
[
group
.
backend
for
group
in
attn_groups
]
all_backends
=
[
group
.
backend
for
group
in
attn_groups
]
backends
=
[
b
for
b
in
all_backends
if
_participates_in_block_size_selection
(
b
)
]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment