sglang · Commits · 0096798e

Unverified commit 0096798e, authored Sep 09, 2025 by fzyzcjy, committed by GitHub Sep 08, 2025

[1/2] Speed up prefill mla attention (#10156)

parent 2c2b19b1

Showing 6 changed files with 130 additions and 0 deletions (+130 / -0)
Changed files:

sgl-kernel/CMakeLists.txt (+1 / -0)
sgl-kernel/csrc/common_extension.cc (+2 / -0)
sgl-kernel/csrc/elementwise/concat_mla.cu (+117 / -0)
sgl-kernel/include/sgl_kernel_ops.h (+1 / -0)
sgl-kernel/python/sgl_kernel/__init__.py (+1 / -0)
sgl-kernel/python/sgl_kernel/elementwise.py (+8 / -0)
sgl-kernel/CMakeLists.txt

@@ -259,6 +259,7 @@ set(SOURCES
     "csrc/elementwise/activation.cu"
     "csrc/elementwise/cast.cu"
     "csrc/elementwise/copy.cu"
+    "csrc/elementwise/concat_mla.cu"
     "csrc/elementwise/fused_add_rms_norm_kernel.cu"
     "csrc/elementwise/rope.cu"
     "csrc/common_extension.cc"
sgl-kernel/csrc/common_extension.cc

@@ -436,6 +436,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
   m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
+  m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
+  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
 }

 REGISTER_EXTENSION(common_ops)
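The two added m.def/m.impl lines register the new op with the sgl_kernel Torch library fragment, so once the compiled extension is loaded the kernel is callable as torch.ops.sgl_kernel.concat_mla_k. A minimal sketch of inspecting the registered schema (assumes the sgl_kernel package with this commit built in is installed; not part of the diff):

import torch
import sgl_kernel  # importing the package loads the compiled extension

# The schema mirrors the m.def() string above; "Tensor! k" marks k as mutated in place.
print(torch.ops.sgl_kernel.concat_mla_k.default._schema)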
sgl-kernel/csrc/elementwise/concat_mla.cu (new file, 0 → 100644)

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <cuda_runtime.h>

#include "pytorch_extension_utils.h"

constexpr int NUM_LOCAL_HEADS = 128;
constexpr int QK_NOPE_HEAD_DIM = 128;
constexpr int QK_ROPE_HEAD_DIM = 64;
constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
constexpr int HEAD_CHUNK_SIZE = 16;
constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;

__forceinline__ __device__ int get_lane_id() {
  int lane_id;
  asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
  return lane_id;
}

int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}

__global__ void concat_mla_k_kernel(
    nv_bfloat16* k,
    nv_bfloat16* k_nope,
    nv_bfloat16* k_rope,
    const int num_tokens,
    const int k_stride_0,
    const int k_stride_1,
    const int k_nope_stride_0,
    const int k_nope_stride_1,
    const int k_rope_stride_0) {
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
  const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
  const int lane_id = get_lane_id();
  if (token_id >= num_tokens) {
    return;
  }

  using KNopeBufType = int2;
  static_assert(sizeof(KNopeBufType) == QK_NOPE_HEAD_DIM * sizeof(k[0]) / 32);
  KNopeBufType k_nope_buf[HEAD_CHUNK_SIZE];

  using KRopeBufType = int;
  static_assert(sizeof(KRopeBufType) == QK_ROPE_HEAD_DIM * sizeof(k[0]) / 32);
  KRopeBufType k_rope_buf;

  {
    const int* base_addr = reinterpret_cast<int*>(k_rope + token_id * k_rope_stride_0);
    k_rope_buf = *(base_addr + lane_id);
  }

#pragma unroll
  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
    const int2* base_addr =
        reinterpret_cast<int2*>(k_nope + token_id * k_nope_stride_0 + head_id * k_nope_stride_1);
    k_nope_buf[i] = *(base_addr + lane_id);
  }

#pragma unroll
  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
    {
      int2* base_addr = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_id * k_stride_1);
      *(base_addr + lane_id) = k_nope_buf[i];
    }
    {
      int* base_addr =
          reinterpret_cast<int*>(k + token_id * k_stride_0 + head_id * k_stride_1 + QK_NOPE_HEAD_DIM);
      *(base_addr + lane_id) = k_rope_buf;
    }
  }
}

inline void check_tensor(
    const at::Tensor& t, int64_t shape0, int64_t shape1, int64_t shape2, c10::ScalarType dtype) {
  TORCH_CHECK_EQ(t.dim(), 3);
  TORCH_CHECK_EQ(t.size(0), shape0);
  TORCH_CHECK_EQ(t.size(1), shape1);
  TORCH_CHECK_EQ(t.size(2), shape2);
  TORCH_CHECK_EQ(t.dtype(), dtype);
  TORCH_CHECK(t.device().is_cuda());
  TORCH_CHECK_EQ(((int64_t)t.data_ptr()) % 16, 0);  // alignment
}

void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
  const int num_tokens = k.size(0);
  check_tensor(k, num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, at::kBFloat16);
  check_tensor(k_nope, num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, at::kBFloat16);
  check_tensor(k_rope, num_tokens, 1, QK_ROPE_HEAD_DIM, at::kBFloat16);
  TORCH_CHECK_EQ(k.stride(2), 1);
  TORCH_CHECK_EQ(k_nope.stride(2), 1);
  TORCH_CHECK_EQ(k_rope.stride(2), 1);

  const auto stream = at::cuda::getCurrentCUDAStream().stream();

  constexpr int num_warps_per_block = 32;
  const int grid_size = ceil_div(num_tokens * NUM_HEAD_CHUNKS, num_warps_per_block);
  const int block_size = num_warps_per_block * 32;

  concat_mla_k_kernel<<<grid_size, block_size, 0, stream>>>(
      reinterpret_cast<nv_bfloat16*>(k.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(k_nope.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(k_rope.data_ptr()),
      num_tokens,
      k.stride(0),
      k.stride(1),
      k_nope.stride(0),
      k_nope.stride(1),
      k_rope.stride(0));

  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}
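For readers skimming the kernel: each warp handles one 16-head chunk of one token. Per head, each lane loads 8 bytes (four bf16 values) of the nope part and 4 bytes (two bf16 values) of the token-wide rope part, so one warp covers the full 128-element nope row and the 64-element rope row, then writes both into the 192-wide destination row of k. In plain PyTorch the result is equivalent to the sketch below (my reading of the kernel above, not code from this commit):

import torch

def concat_mla_k_reference(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None:
    # k:      [num_tokens, 128, 192], bfloat16, written in place
    # k_nope: [num_tokens, 128, 128], bfloat16
    # k_rope: [num_tokens,   1,  64], bfloat16, shared by all 128 heads
    k[..., :128].copy_(k_nope)
    k[..., 128:].copy_(k_rope.expand(-1, k.size(1), -1))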
sgl-kernel/include/sgl_kernel_ops.h

@@ -723,3 +723,4 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
 void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
 void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
+void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
sgl-kernel/python/sgl_kernel/__init__.py

@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
 from sgl_kernel.elementwise import (
     FusedSetKVBufferArg,
     apply_rope_with_cos_sin_cache_inplace,
+    concat_mla_k,
     copy_to_gpu_no_ce,
     downcast_fp8,
     fused_add_rmsnorm,
sgl-kernel/python/sgl_kernel/elementwise.py

@@ -371,3 +371,11 @@ def downcast_fp8(
 def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor):
     torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)
+
+
+def concat_mla_k(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
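A minimal usage sketch of the new Python wrapper, assuming a CUDA device and the shapes/dtype enforced by the checks in concat_mla.cu (128 local heads, 128-dim nope part, 64-dim rope part, bfloat16); the values and token count here are illustrative only:

import torch
from sgl_kernel import concat_mla_k

num_tokens = 4
k_nope = torch.randn(num_tokens, 128, 128, dtype=torch.bfloat16, device="cuda")
k_rope = torch.randn(num_tokens, 1, 64, dtype=torch.bfloat16, device="cuda")
k = torch.empty(num_tokens, 128, 192, dtype=torch.bfloat16, device="cuda")

concat_mla_k(k, k_nope, k_rope)  # fills k in place

# Cross-check against an equivalent torch.cat of the broadcast rope part.
ref = torch.cat([k_nope, k_rope.expand(-1, 128, -1)], dim=-1)
torch.testing.assert_close(k, ref)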