Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7096898
"docs/vscode:/vscode.git/clone" did not exist on "704432af3c129b7a57fca9b059eefe214159f836"
Commit
e7096898
authored
Mar 02, 2026
by
liuchy5
Browse files
Dsa supported.
parent
1ce0a9a2
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
207 additions
and
51 deletions
+207
-51
csrc/cache.h
csrc/cache.h
+6
-1
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+107
-0
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+1
-1
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
+11
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+5
-7
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+1
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+3
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+3
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+9
-9
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+10
-9
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+4
-1
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+33
-16
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-3
No files found.
csrc/cache.h
View file @
e7096898
...
...
@@ -80,7 +80,12 @@ void indexer_k_quant_and_cache(
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
int64_t
quant_block_size
,
// quantization block size
const
std
::
string
&
scale_fmt
);
// Indexer K cache function
void
indexer_k_cache
(
torch
::
Tensor
&
k
,
// [num_tokens, head_dim]
torch
::
Tensor
&
kv_cache
,
// [num_blocks, block_size, cache_stride]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
scale_fmt
);
// Extract function to gather quantized K cache
void
cp_gather_indexer_k_quant_cache
(
...
...
csrc/cache_kernels.cu
View file @
e7096898
...
...
@@ -600,6 +600,52 @@ __global__ void indexer_k_quant_and_cache_kernel(
reinterpret_cast
<
float
*>
(
kv_cache
)[
dst_scale_idx
/
4
]
=
scale
;
}
}
template
<
typename
scalar_t
,
typename
cache_t
>
__global__
void
indexer_k_cache_kernel
(
const
scalar_t
*
__restrict__
k
,
// [num_tokens, head_dim]
cache_t
*
__restrict__
kv_cache
,
// [num_blocks, block_size, cache_stride]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
head_dim
,
// dimension of each head
const
int
cache_block_size
,
// cache block size
const
int
cache_stride
// stride for each token in kv_cache
)
{
constexpr
int
VEC_SIZE
=
4
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
head_dim_idx
=
(
blockIdx
.
y
*
blockDim
.
y
*
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
)
*
VEC_SIZE
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
const
int64_t
block_idx
=
slot_idx
/
cache_block_size
;
const
int64_t
block_offset
=
slot_idx
%
cache_block_size
;
// NOTE: slot_idx can be -1 if the token is padded
if
(
slot_idx
<
0
||
(
head_dim_idx
>=
head_dim
))
{
return
;
}
float2
k_val
=
(
reinterpret_cast
<
const
float2
*>
(
k
))[(
token_idx
*
head_dim
+
head_dim_idx
)
/
VEC_SIZE
];
scalar_t
*
k_val_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
k_val
);
const
int64_t
dst_offset
=
block_idx
*
cache_block_size
*
cache_stride
+
block_offset
*
head_dim
+
head_dim_idx
;
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
i
++
)
{
float
val
=
static_cast
<
float
>
(
k_val_ptr
[
i
]);
if
constexpr
(
std
::
is_same
<
cache_t
,
at
::
Half
>::
value
||
std
::
is_same
<
cache_t
,
__half
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
__float2half
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
cache_t
,
at
::
BFloat16
>::
value
||
std
::
is_same
<
cache_t
,
__nv_bfloat16
>::
value
)
{
__hip_bfloat16
bf16_val
=
__float2bfloat16
(
val
);
kv_cache
[
dst_offset
+
i
]
=
*
reinterpret_cast
<
at
::
BFloat16
*>
(
&
bf16_val
);
}
else
if
constexpr
(
std
::
is_same
<
cache_t
,
float
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
val
;
}
else
{
kv_cache
[
dst_offset
+
i
]
=
static_cast
<
cache_t
>
(
val
);
}
}
}
template
<
int
BLOCK_Y_SIZE
>
__global__
void
cp_gather_indexer_k_quant_cache_kernel
(
...
...
@@ -1504,3 +1550,64 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE
(
32
);
}
}
void
indexer_k_cache
(
torch
::
Tensor
&
k
,
// [num_tokens, head_dim]
torch
::
Tensor
&
kv_cache
,
// [num_blocks, block_size, cache_stride]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
scale_fmt
)
{
int
num_tokens
=
k
.
size
(
0
);
int
head_dim
=
k
.
size
(
1
);
int
cache_block_size
=
kv_cache
.
size
(
1
);
int
cache_stride
=
kv_cache
.
size
(
2
);
bool
use_ue8m0
=
scale_fmt
==
"ue8m0"
;
TORCH_CHECK
(
k
.
device
()
==
kv_cache
.
device
(),
"k and kv_cache must be on the same device"
);
TORCH_CHECK
(
k
.
device
()
==
slot_mapping
.
device
(),
"k and slot_mapping must be on the same device"
);
constexpr
int
vec_size
=
4
;
dim3
grid
(
num_tokens
,
(
head_dim
+
vec_size
-
1
)
/
vec_size
);
dim3
block
(
32
,
vec_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
k
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
k
.
scalar_type
(),
"indexer_k_cache"
,
([
&
]
{
using
k_t
=
scalar_t
;
auto
kv_cache_type
=
kv_cache
.
scalar_type
();
if
(
kv_cache_type
==
at
::
ScalarType
::
Float
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
float
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
if
(
kv_cache_type
==
at
::
ScalarType
::
Half
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
at
::
Half
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
at
::
Half
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
if
(
kv_cache_type
==
at
::
ScalarType
::
BFloat16
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
at
::
BFloat16
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
at
::
BFloat16
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported kv_cache dtype: "
,
kv_cache
.
dtype
());
}
}));
}
csrc/fused_qknorm_rope_kernel.cu
View file @
e7096898
...
...
@@ -38,7 +38,7 @@
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
//
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// #if defined(HIP_VERSION) && HIP_VERSION < 70000000
// // On ROCm versions before 7.0, __syncwarp isn't defined. The below
// // implementation is copy/pasted from the implementation in ROCm 7.0
// __device__ inline void __syncwarp() {
...
...
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
View file @
e7096898
...
...
@@ -778,6 +778,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
...
...
csrc/torch_bindings.cpp
View file @
e7096898
...
...
@@ -816,6 +816,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"int quant_block_size, str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"indexer_k_quant_and_cache"
,
torch
::
kCUDA
,
&
indexer_k_quant_and_cache
);
cache_ops
.
def
(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"indexer_k_cache"
,
torch
::
kCUDA
,
&
indexer_k_cache
);
cache_ops
.
def
(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
...
...
vllm/_custom_ops.py
View file @
e7096898
...
...
@@ -2873,13 +2873,11 @@ def cp_gather_indexer_k_quant_cache(
)
def
indexer_k_quant_and_cache
(
k
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
quant_block_size
:
int
,
kv_cache_dtype
:
str
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
kv_cache_dtype
)
def
indexer_k_cache
(
k
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
indexer_k_cache
(
k
,
kv_cache
,
slot_mapping
,
kv_cache_dtype
)
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
e7096898
...
...
@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
getattr
(
layer
,
"_marlin_w16a16_moe_enabled"
,
False
)
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
e7096898
...
...
@@ -295,8 +295,9 @@ class SparseAttnIndexer(CustomOp):
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
):
if
rocm_aiter_ops
.
is_enabled
():
return
torch
.
ops
.
vllm
.
rocm_aiter_sparse_attn_indexer
(
#if rocm_aiter_ops.is_enabled():
if
current_platform
.
is_rocm
():
return
rocm_aiter_sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
e7096898
...
...
@@ -700,7 +700,7 @@ class Indexer(nn.Module):
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
...
...
@@ -712,7 +712,8 @@ class Indexer(nn.Module):
)
q_fp8
=
q_fp8
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_scale
=
q_scale
.
view
(
-
1
,
self
.
n_head
,
1
)
else
:
q_fp8
=
q
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
weights
=
(
...
...
vllm/platforms/rocm.py
View file @
e7096898
# SPDX-License-Identifier: Apache-2.0
#
-l
SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
...
...
@@ -261,15 +261,15 @@ class RocmPlatform(Platform):
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
attn_selector_config
.
use_sparse
:
if
kv_cache_dtype
and
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
ValueError
(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
assert
block_size
==
1
,
(
"Sparse MLA backend on ROCm only supports block size 1 for now."
)
#
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
#
raise ValueError(
#
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
#
)
#
assert block_size == 1, (
#
"Sparse MLA backend on ROCm only supports block size 1 for now."
#
)
logger
.
info_once
(
"Using Sparse MLA backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_
MLA_SPARSE
.
get_path
()
return
AttentionBackendEnum
.
FLASH
MLA_SPARSE
.
get_path
()
if
attn_selector_config
.
use_mla
:
# if attn_selector_config.use_sparse:
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
e7096898
...
...
@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class
DeepseekV32IndexerBackend
(
AttentionBackend
):
exclude_from_block_size_selection
=
True
@
staticmethod
def
get_name
()
->
str
:
return
"DEEPSEEK_V32_INDEXER"
...
...
@@ -323,15 +324,15 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding
=
(
decode_lens_cpu
.
max
()
>
decode_lens_cpu
.
min
()).
item
()
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
]
if
is_deep_gemm_supported
():
if
current_platform
.
is_rocm
():
self
.
scheduler_metadata_buffer
=
gemmopt
.
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
else
:
self
.
scheduler_metadata_buffer
[:]
=
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
#
if is_deep_gemm_supported():
if
current_platform
.
is_rocm
():
self
.
scheduler_metadata_buffer
=
gemmopt
.
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
else
:
self
.
scheduler_metadata_buffer
[:]
=
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
decode_metadata
=
DeepSeekV32IndexerDecodeMetadata
(
block_table
=
common_attn_metadata
.
block_table_tensor
[:
num_decodes
,
...],
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
],
...
...
vllm/v1/attention/ops/flashmla.py
View file @
e7096898
...
...
@@ -31,7 +31,10 @@ else:
if
current_platform
.
is_rocm
():
import
flash_mla.cuda
as
flash_mla_cuda
#from vllm.v1.attention.ops import flashmla
#flash_mla_cuda = flashmla.flash_mla_cuda
from
flash_mla.flash_mla_interface
import
flash_mla_cuda
#import flash_mla.cuda as flash_mla_cuda
_flashmla_C_AVAILABLE
=
True
_flashmla_extension_C_AVAILABLE
=
True
...
...
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
e7096898
...
...
@@ -505,7 +505,7 @@ def rocm_aiter_sparse_attn_indexer(
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
# careful! this will be None in dummy run
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
# assert isinstance(attn_metadata, dict)
...
...
@@ -555,21 +555,38 @@ def rocm_aiter_sparse_attn_indexer(
dtype
=
torch
.
uint8
,
)
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
logits
=
rocm_fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
)),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
logits
=
rocm_fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
)),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
else
:
#k_fp8 = torch.empty(
# [chunk.total_seq_lens, head_dim],
# device=k.device,
# dtype=k.dtype,
#)
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k
.
shape
[
0
],
64
,
128
,
True
,
)
num_rows
=
logits
.
shape
[
0
]
assert
topk_tokens
==
2048
,
"top_k_per_row assumes size 2048"
topk_indices
=
topk_indices_buffer
[
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
e7096898
...
...
@@ -5536,7 +5536,9 @@ class GPUModelRunner(
Raises:
ValueError: If no valid block size found
"""
#exclude indexer backend
def
_participates_in_block_size_selection
(
backend
:
type
[
AttentionBackend
])
->
bool
:
return
not
getattr
(
backend
,
"exclude_from_block_size_selection"
,
False
)
def
block_size_is_supported
(
backends
:
list
[
type
[
AttentionBackend
]],
block_size
:
int
)
->
bool
:
...
...
@@ -5557,8 +5559,11 @@ class GPUModelRunner(
if
not
is_supported
:
return
False
return
True
backends
=
[
group
.
backend
for
group
in
attn_groups
]
all_backends
=
[
group
.
backend
for
group
in
attn_groups
]
backends
=
[
b
for
b
in
all_backends
if
_participates_in_block_size_selection
(
b
)
]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment