Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47b93395
Unverified
Commit
47b93395
authored
Oct 02, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 02, 2025
Browse files
[DeepSeek] Improve performance of DS MLA cache kernel (#26132)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
5d5146ee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
68 deletions
+62
-68
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+62
-68
No files found.
csrc/cache_kernels.cu
View file @
47b93395
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include <algorithm>
#include <algorithm>
#include <cassert>
#include <cassert>
#include <cfloat> // FLT_MIN
#include <map>
#include <map>
#include <vector>
#include <vector>
...
@@ -424,84 +423,80 @@ __global__ void concat_and_cache_ds_mla_kernel(
...
@@ -424,84 +423,80 @@ __global__ void concat_and_cache_ds_mla_kernel(
const
int64_t
dst_idx_start
=
const
int64_t
dst_idx_start
=
block_idx
*
block_stride
+
block_offset
*
entry_stride
;
block_idx
*
block_stride
+
block_offset
*
entry_stride
;
// Create 4 tile scales in shared memory
// For the NoPE part, each tile of 128 elements is handled by half of one warp
__shared__
float
smem
[
20
];
// (16 threads). There are 4 total tiles, so 2 warps (64 threads).
float
*
shard_abs_max
=
smem
;
// Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
float
*
tile_scales
=
smem
+
16
;
// The RoPE part (last 64 elements) is handled by another 1 warp (32 threads).
// So in total, we use 3 warps (96 threads) per block.
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// The first thread of the first warp in each tile writes the scale
// value for the tile. The RoPE part (last 64 elements) is handled
// by another 2 warps (64 threads).
// So in total, we use 18 warps (576 threads) per block.
// Cast kv_cache to 16_bit for RoPE values
// Cast kv_cache to 16_bit for RoPE values
scalar_t
*
kv_cache_16bit
=
scalar_t
*
kv_cache_16bit
=
reinterpret_cast
<
scalar_t
*>
(
&
kv_cache
[
dst_idx_start
]);
reinterpret_cast
<
scalar_t
*>
(
&
kv_cache
[
dst_idx_start
]);
// The last 64 threads handle the RoPE part
// The last warp handles the RoPE part
if
(
threadIdx
.
x
>=
kv_lora_rank
)
{
if
(
threadIdx
.
x
>=
64
)
{
const
int8_t
pe_idx
=
threadIdx
.
x
-
kv_lora_rank
;
// Each thread handles two elements of RoPE
const
int64_t
src_idx
=
token_idx
*
k_pe_stride
+
pe_idx
;
const
int8_t
pe_idx_start
=
(
threadIdx
.
x
-
64
)
*
2
;
const
int64_t
src_idx
=
token_idx
*
k_pe_stride
+
pe_idx_start
;
// Vectorized load of two 16-bit values, performed as one 32-bit load
const
int32_t
vals
=
*
reinterpret_cast
<
const
int32_t
*>
(
&
k_pe
[
src_idx
]);
// RoPE values start after the packed 8-bit NoPE values and the
// RoPE values start after the packed 8-bit NoPE values and the
// 32-bit scales
// 32-bit scales
const
int64_t
dst_idx
=
kv_lora_rank
/
2
+
8
+
pe_idx
;
const
int64_t
dst_idx
=
kv_lora_rank
/
2
+
8
+
pe_idx_start
;
kv_cache_16bit
[
dst_idx
]
=
k_pe
[
src_idx
];
// Vectorized store of two 16-bit values, performed as one 32-bit store
*
reinterpret_cast
<
int32_t
*>
(
&
kv_cache_16bit
[
dst_idx
])
=
vals
;
return
;
return
;
}
}
// Determine the scale for each chunk of NoPE
// The first two warps handle the NoPE part
const
int16_t
tile_idx
=
threadIdx
.
x
>>
7
;
const
int8_t
warp_idx
=
threadIdx
.
x
>>
5
;
const
int16_t
warp_idx
=
(
threadIdx
.
x
&
127
)
>>
5
;
const
int8_t
lane_idx
=
threadIdx
.
x
&
31
;
const
int16_t
lane_idx
=
threadIdx
.
x
&
31
;
const
int8_t
tile_idx
=
warp_idx
*
2
+
(
lane_idx
>>
4
);
// Load the NoPE element for this thread into registers
// Each thread handles 8 elements of NoPE
const
int64_t
src_idx
=
token_idx
*
kv_c_stride
+
threadIdx
.
x
;
// Load the NoPE elements for this thread into registers
const
scalar_t
src_val
=
kv_c
[
src_idx
];
const
int64_t
src_idx_start
=
token_idx
*
kv_c_stride
+
(
threadIdx
.
x
*
8
);
// Vectorized load of eight 16-bit values, performed as an int4 load
// Warp-level reduction to find the max absolute value in the warp
const
int4
vals_i4
=
*
reinterpret_cast
<
const
int4
*>
(
&
kv_c
[
src_idx_start
]);
float
max_abs
=
fabsf
(
src_val
);
const
scalar_t
*
vals
=
reinterpret_cast
<
const
scalar_t
*>
(
&
vals_i4
);
// Max absolute value of this thread's elements
float
max_abs
=
fmaxf
(
fmaxf
(
fmaxf
(
fabsf
(
vals
[
0
]),
fabsf
(
vals
[
1
])),
fmaxf
(
fabsf
(
vals
[
2
]),
fabsf
(
vals
[
3
]))),
fmaxf
(
fmaxf
(
fabsf
(
vals
[
4
]),
fabsf
(
vals
[
5
])),
fmaxf
(
fabsf
(
vals
[
6
]),
fabsf
(
vals
[
7
]))));
// Warp-level reduction to find the max absolute value in each half-warp
#pragma unroll
#pragma unroll
for
(
int
offset
=
16
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
offset
=
8
;
offset
>
0
;
offset
/=
2
)
{
#ifdef USE_ROCM
max_abs
=
fmaxf
(
max_abs
,
VLLM_SHFL_XOR_SYNC_WIDTH
(
max_abs
,
offset
,
16
));
max_abs
=
fmaxf
(
max_abs
,
__shfl_down_sync
(
UINT64_MAX
,
max_abs
,
offset
));
#else
max_abs
=
fmaxf
(
max_abs
,
__shfl_down_sync
(
0xFFFFFFFF
,
max_abs
,
offset
));
#endif
}
}
// The first lane of each warp in each tile writes the max_abs of this part
// Compute the scale for the tile
// of the tile to shared memory
float
tile_scale
=
max_abs
/
448.
f
;
if
(
lane_idx
==
0
)
{
shard_abs_max
[
tile_idx
*
4
+
warp_idx
]
=
max_abs
;
// The first lane of each half-warp writes the scale to kv_cache
}
if
((
lane_idx
==
0
)
||
(
lane_idx
==
16
))
{
__syncthreads
();
// The first lane of the first warp in each tile computes the scale for the
// tile and writes it to shared memory and to kv_cache
if
(
warp_idx
==
0
&&
lane_idx
==
0
)
{
float4
shard_abs_max_vec
=
reinterpret_cast
<
float4
*>
(
shard_abs_max
)[
tile_idx
];
float
tile_scale
=
fmaxf
(
fmaxf
(
shard_abs_max_vec
.
x
,
shard_abs_max_vec
.
y
),
fmaxf
(
shard_abs_max_vec
.
z
,
shard_abs_max_vec
.
w
))
/
448.
f
;
// Avoid division by zero in `scaled_convert`
tile_scales
[
tile_idx
]
=
fmaxf
(
tile_scale
,
FLT_MIN
);
float
*
kv_cache_32bit
=
reinterpret_cast
<
float
*>
(
&
kv_cache
[
dst_idx_start
]);
float
*
kv_cache_32bit
=
reinterpret_cast
<
float
*>
(
&
kv_cache
[
dst_idx_start
]);
const
uint64_t
dst_idx
=
kv_lora_rank
/
4
+
tile_idx
;
const
uint64_t
dst_idx
=
kv_lora_rank
/
4
+
tile_idx
;
kv_cache_32bit
[
dst_idx
]
=
tile_scale
s
[
tile_idx
]
;
kv_cache_32bit
[
dst_idx
]
=
tile_scale
;
}
}
__syncthreads
();
// Now all threads in the block scale and write their elements
// NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes)
const
int64_t
dst_idx_base
=
dst_idx_start
+
(
threadIdx
.
x
*
8
);
uint8_t
result
[
8
];
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
result
[
i
]
=
fp8
::
scaled_convert
<
uint8_t
,
scalar_t
,
Fp8KVCacheDataType
::
kFp8E4M3
>
(
vals
[
i
],
tile_scale
);
}
// Now all threads in the block scale and write their element
// Store as aligned 64-bit writes
const
float
scale_val
=
tile_scales
[
tile_idx
];
*
reinterpret_cast
<
uint64_t
*>
(
&
kv_cache
[
dst_idx_base
])
=
const
int64_t
dst_idx
=
dst_idx_start
+
threadIdx
.
x
;
*
reinterpret_cast
<
const
uint64_t
*>
(
result
);
kv_cache
[
dst_idx
]
=
fp8
::
scaled_convert
<
uint8_t
,
scalar_t
,
Fp8KVCacheDataType
::
kFp8E4M3
>
(
src_val
,
scale_val
);
}
}
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
...
@@ -741,13 +736,12 @@ void concat_and_cache_mla(
...
@@ -741,13 +736,12 @@ void concat_and_cache_mla(
if
(
kv_cache_dtype
==
"fp8_ds_mla"
)
{
if
(
kv_cache_dtype
==
"fp8_ds_mla"
)
{
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// For the NoPE part, each tile of 128 elements is handled by half of one
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
// The first thread of the first warp in each tile writes the scale
// Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
// value for the tile. The RoPE part (last 64 elements) is handled
// The RoPE part (last 64 elements) is handled by another 1 warp (32
// by another 2 warps (64 threads).
// threads). So in total, we use 3 warps (96 threads) per block.
// So in total, we use 18 warps (576 threads) per block.
dim3
block
(
96
);
dim3
block
(
576
);
DISPATCH_BY_KV_CACHE_DTYPE
(
kv_c
.
dtype
(),
kv_cache_dtype
,
DISPATCH_BY_KV_CACHE_DTYPE
(
kv_c
.
dtype
(),
kv_cache_dtype
,
CALL_CONCAT_AND_CACHE_DS_MLA
);
CALL_CONCAT_AND_CACHE_DS_MLA
);
}
else
{
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment