sglang / Commits / 3b25dc12

[1/2] Speed up trtllm_mla attention backend (>10% e2e) (#10473)

Unverified commit 3b25dc12, authored Sep 16, 2025 by fzyzcjy; committed by GitHub Sep 15, 2025. Parent: 5c08d7d2.

Showing 6 changed files with 119 additions and 3 deletions (+119 -3):
sgl-kernel/csrc/common_extension.cc           +3    -0
sgl-kernel/csrc/elementwise/concat_mla.cu     +102  -0
sgl-kernel/include/sgl_kernel_ops.h           +1    -0
sgl-kernel/python/sgl_kernel/__init__.py      +1    -0
sgl-kernel/python/sgl_kernel/elementwise.py   +12   -0
test/srt/models/test_generation_models.py     +0    -3
sgl-kernel/csrc/common_extension.cc

@@ -104,6 +104,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);

  m.def("concat_mla_absorb_q(Tensor a, Tensor b, Tensor! out) -> ()");
  m.impl("concat_mla_absorb_q", torch::kCUDA, &concat_mla_absorb_q);

  /*
   * From csrc/gemm
   */
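The schema registered here declares out as mutated (Tensor! out) and binds the CUDA implementation, so Python reaches the kernel through torch.ops. A minimal sketch of that raw call path, assuming the sgl-kernel extension is built and importable (shapes are illustrative, not from the commit):

# Hedged sketch: invoking the newly registered op directly through torch.ops.
# Assumes importing sgl_kernel loads the extension so the TORCH_LIBRARY_FRAGMENT above runs.
import torch
import sgl_kernel  # noqa: F401  (importing registers the ops)

a = torch.randn(8, 16, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8, 16, 64, device="cuda", dtype=torch.bfloat16)
out = torch.empty(8, 16, 576, device="cuda", dtype=torch.bfloat16)

# The op writes into `out` in place, matching the "Tensor! out" annotation.
torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)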
sgl-kernel/csrc/elementwise/concat_mla.cu

@@ -115,3 +115,105 @@ void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}

// ============================== concat_mla_absorb_q ==============================
// TODO give a name prefix, also maybe refactor code above

constexpr int A_LAST_DIM = 512;
constexpr int B_LAST_DIM = 64;

__global__ void concat_mla_absorb_q_kernel(
    nv_bfloat16* a,
    nv_bfloat16* b,
    nv_bfloat16* out,
    const int num_items,
    const int dim_1,
    const int a_stride_0,
    const int a_stride_1,
    const int b_stride_0,
    const int b_stride_1,
    const int out_stride_0,
    const int out_stride_1) {
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  const int lane_id = get_lane_id();
  const int idx_0 = flat_warp_id / dim_1;
  const int idx_1 = flat_warp_id % dim_1;
  if (flat_warp_id >= num_items) {
    return;
  }

  using ABufType = int4;
  constexpr int A_NUM_UNROLL = 2;
  static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32);
  ABufType a_buf[A_NUM_UNROLL];

  using BBufType = int;
  constexpr int B_NUM_UNROLL = 1;
  static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32);
  BBufType b_buf;

  {
    const BBufType* base_addr = reinterpret_cast<BBufType*>(b + idx_0 * b_stride_0 + idx_1 * b_stride_1);
    b_buf = *(base_addr + lane_id);
  }

#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
    a_buf[i] = *(base_addr + i * 32 + lane_id);
  }

  {
    BBufType* base_addr = reinterpret_cast<BBufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM);
    *(base_addr + lane_id) = b_buf;
  }

#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
    *(base_addr + i * 32 + lane_id) = a_buf[i];
  }
}

inline void check_tensor_concat_mla_absorb_q(const at::Tensor& t, int64_t shape2) {
  TORCH_CHECK_EQ(t.dim(), 3);
  TORCH_CHECK_EQ(t.size(2), shape2);
  TORCH_CHECK_EQ(t.stride(2), 1);
  TORCH_CHECK_EQ(t.dtype(), at::kBFloat16);
  TORCH_CHECK(t.device().is_cuda());
  TORCH_CHECK_EQ(((int64_t)t.data_ptr()) % 16, 0);  // alignment
}

// TODO further optimize it later
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out) {
  check_tensor_concat_mla_absorb_q(a, A_LAST_DIM);
  check_tensor_concat_mla_absorb_q(b, B_LAST_DIM);
  check_tensor_concat_mla_absorb_q(out, A_LAST_DIM + B_LAST_DIM);

  const auto stream = at::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(a.size(0) * a.size(1), b.size(0) * b.size(1));
  TORCH_CHECK_EQ(a.size(1), b.size(1));

  const int num_items = a.size(0) * a.size(1);
  constexpr int num_warps_per_block = 32;
  const int grid_size = ceil_div(num_items, num_warps_per_block);
  const int block_size = num_warps_per_block * 32;

  concat_mla_absorb_q_kernel<<<grid_size, block_size, 0, stream>>>(
      reinterpret_cast<nv_bfloat16*>(a.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(b.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
      num_items,
      a.size(1),
      a.stride(0),
      a.stride(1),
      b.stride(0),
      b.stride(1),
      out.stride(0),
      out.stride(1));

  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}
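Reading the new kernel: each warp owns one (idx_0, idx_1) row, and the static_asserts pin the per-lane vector widths — 512 bf16 values of a are 1024 bytes, i.e. two 16-byte int4 loads per lane across 32 lanes, while the 64 bf16 values of b are 128 bytes, one 4-byte int load per lane. The row of a lands in out[..., :512] and the row of b is appended at out[..., 512:576], so functionally the op should behave like a last-dimension concatenation. A hedged pure-PyTorch reference for sanity-checking that reading (not part of the commit; shapes are illustrative):

# Hedged reference: what concat_mla_absorb_q appears to compute, in plain PyTorch.
# Useful only as a sanity check of the kernel's semantics, not as a replacement.
import torch
from sgl_kernel import concat_mla_absorb_q  # wrapper added by this commit

def concat_mla_absorb_q_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: [..., 512], b: [..., 64], both bfloat16 on CUDA -> out: [..., 576]
    return torch.cat([a, b], dim=-1)

a = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 128, 64, device="cuda", dtype=torch.bfloat16)
torch.testing.assert_close(concat_mla_absorb_q(a, b), concat_mla_absorb_q_ref(a, b))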
sgl-kernel/include/sgl_kernel_ops.h

@@ -172,6 +172,7 @@ void downcast_fp8(
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);

#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
sgl-kernel/python/sgl_kernel/__init__.py

@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
from sgl_kernel.elementwise import (
    FusedSetKVBufferArg,
    apply_rope_with_cos_sin_cache_inplace,
    concat_mla_absorb_q,
    concat_mla_k,
    copy_to_gpu_no_ce,
    downcast_fp8,
sgl-kernel/python/sgl_kernel/elementwise.py

@@ -379,3 +379,15 @@ def concat_mla_k(
    k_rope: torch.Tensor,
):
    torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)


def concat_mla_absorb_q(
    a: torch.Tensor,
    b: torch.Tensor,
):
    *batch_dims, _ = a.shape
    out = torch.empty(
        (*batch_dims, a.shape[-1] + b.shape[-1]), device=a.device, dtype=a.dtype
    )
    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
    return out
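A hedged usage sketch of the wrapper added above (variable names and shapes are illustrative, not from the commit). The wrapper allocates the [..., 576] output itself and returns it; the CUDA-side checks only require 3-D bfloat16 CUDA tensors whose last dimension is contiguous (stride(2) == 1) and sized 512 / 64, so a row-strided view such as the slice below should be accepted without an extra .contiguous() copy:

# Hedged usage sketch; tensor names are placeholders for the absorbed-q MLA inputs.
import torch
from sgl_kernel import concat_mla_absorb_q

q_nope = torch.randn(64, 16, 512, device="cuda", dtype=torch.bfloat16)
q_rope_full = torch.randn(64, 16, 128, device="cuda", dtype=torch.bfloat16)
q_rope = q_rope_full[..., :64]  # non-contiguous view, but last dim is still stride-1

q = concat_mla_absorb_q(q_nope, q_rope)  # allocates and returns a new [64, 16, 576] tensor
assert q.shape == (64, 16, 576)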
test/srt/models/test_generation_models.py

@@ -67,11 +67,8 @@ ALL_MODELS = [
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/phi-1_5", trust_remote_code=True),
    ModelCase("adept/persimmon-8b-chat"),
    ModelCase("upstage/SOLAR-10.7B-Instruct-v1.0"),
    ModelCase("inclusionAI/Ling-lite", trust_remote_code=True),
    ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
    ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),