Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
40e2eeeb
Unverified
Commit
40e2eeeb
authored
Nov 11, 2025
by
caozuoba
Committed by
GitHub
Nov 10, 2025
Browse files
[Kernel] Optimization of the mm_k operator. (#28280)
Co-authored-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
b06b9470
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
18 deletions
+51
-18
vllm/lora/ops/triton_ops/kernel_utils.py
vllm/lora/ops/triton_ops/kernel_utils.py
+51
-18
No files found.
vllm/lora/ops/triton_ops/kernel_utils.py
View file @
40e2eeeb
...
@@ -23,6 +23,7 @@ def mm_k(
...
@@ -23,6 +23,7 @@ def mm_k(
CAST_TYPE
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
b_dtype
:
tl
.
constexpr
,
b_dtype
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
base_k
,
):
):
"""
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
...
@@ -47,32 +48,62 @@ def mm_k(
...
@@ -47,32 +48,62 @@ def mm_k(
matrix dtype.
matrix dtype.
b_dtype: datatype of the B matrix
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
USE_GDC: Whether to use PDL. True indicates use.
base_k: Base offset along K dimension for current SPLIT_K group
"""
"""
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)):
# Step size along K for each iteration
STEP_K
=
BLOCK_K
*
SPLIT_K
# Total number of iterations (compile-time constant)
num_iters
=
tl
.
cdiv
(
K
,
STEP_K
)
for
k
in
range
(
num_iters
):
# Current iteration's global K offset
iter_k
=
k
*
STEP_K
+
base_k
# Check if this iteration is completely valid (no masking needed)
block_end
=
iter_k
+
BLOCK_K
if
EVEN_K
:
if
EVEN_K
:
# pre-fetech lora weight
# K is divisible by BLOCK_K, no masking ever needed
# pre-fetch lora weight
tiled_b
=
tl
.
load
(
b_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
if
USE_GDC
:
if
USE_GDC
:
tl
.
extra
.
cuda
.
gdc_wait
()
tl
.
extra
.
cuda
.
gdc_wait
()
tiled_a
=
tl
.
load
(
a_ptr
)
tiled_a
=
tl
.
load
(
a_ptr
)
if
CAST_TYPE
:
tiled_a
=
tiled_a
.
to
(
b_dtype
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
)
else
:
else
:
tiled_b
=
tl
.
load
(
# Check if we need element-wise masking
b_ptr
,
mask
=
offset_k
[:,
None
]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
if
iter_k
>=
K
:
)
# Entire block out of range, skip
if
USE_GDC
:
pass
tl
.
extra
.
cuda
.
gdc_wait
()
elif
block_end
<=
K
:
tiled_a
=
tl
.
load
(
# Entire block in range, no masking needed (fast path)
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
tiled_b
=
tl
.
load
(
b_ptr
)
)
if
USE_GDC
:
if
CAST_TYPE
:
tl
.
extra
.
cuda
.
gdc_wait
()
tiled_a
=
tiled_a
.
to
(
b_dtype
)
tiled_a
=
tl
.
load
(
a_ptr
)
accumulator
+=
tl
.
dot
(
if
CAST_TYPE
:
tiled_a
,
tiled_a
=
tiled_a
.
to
(
b_dtype
)
tiled_b
,
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
)
)
else
:
a_ptr
+=
BLOCK_K
*
SPLIT_K
*
ak_stride
# Partial block, need masking (only last iteration)
b_ptr
+=
BLOCK_K
*
SPLIT_K
*
bk_stride
k_offsets
=
tl
.
arange
(
0
,
BLOCK_K
)
mask
=
iter_k
+
k_offsets
<
K
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
mask
[:,
None
],
other
=
0.0
)
if
USE_GDC
:
tl
.
extra
.
cuda
.
gdc_wait
()
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
mask
[
None
,
:],
other
=
0.0
)
if
CAST_TYPE
:
tiled_a
=
tiled_a
.
to
(
b_dtype
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
)
a_ptr
+=
STEP_K
*
ak_stride
b_ptr
+=
STEP_K
*
bk_stride
return
accumulator
return
accumulator
...
@@ -178,6 +209,7 @@ def do_expand_kernel(
...
@@ -178,6 +209,7 @@ def do_expand_kernel(
CAST_TYPE
,
CAST_TYPE
,
cur_lora_ptr
.
dtype
.
element_ty
,
cur_lora_ptr
.
dtype
.
element_ty
,
USE_GDC
,
USE_GDC
,
base_k
=
0
,
)
)
tiled_c
=
accumulator
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
tiled_c
=
accumulator
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
...
@@ -284,6 +316,7 @@ def do_shrink_kernel(
...
@@ -284,6 +316,7 @@ def do_shrink_kernel(
False
,
False
,
cur_lora_ptr
.
dtype
.
element_ty
,
cur_lora_ptr
.
dtype
.
element_ty
,
False
,
# USE_GDC is always False in shrink kernel
False
,
# USE_GDC is always False in shrink kernel
base_k
=
pid_sk
*
BLOCK_K
,
)
)
# GDC launch dependents hints the runtime system to launch dependent kernels.
# GDC launch dependents hints the runtime system to launch dependent kernels.
if
USE_GDC
:
if
USE_GDC
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment