Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
21b82f4e
Unverified
Commit
21b82f4e
authored
Nov 07, 2025
by
Jee Jee Li
Committed by
GitHub
Nov 07, 2025
Browse files
[Kernel] LoRA triton kernels support PDL (#27402)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
a736e5ff
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
17 deletions
+68
-17
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+22
-7
vllm/lora/ops/triton_ops/kernel_utils.py
vllm/lora/ops/triton_ops/kernel_utils.py
+21
-7
vllm/lora/ops/triton_ops/lora_expand_op.py
vllm/lora/ops/triton_ops/lora_expand_op.py
+7
-1
vllm/lora/ops/triton_ops/lora_shrink_op.py
vllm/lora/ops/triton_ops/lora_shrink_op.py
+7
-2
vllm/lora/ops/triton_ops/utils.py
vllm/lora/ops/triton_ops/utils.py
+11
-0
No files found.
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
21b82f4e
...
@@ -6,6 +6,8 @@ import torch
...
@@ -6,6 +6,8 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
.utils
import
supports_pdl
_LORA_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
_LORA_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
...
@@ -82,6 +84,8 @@ def _fused_moe_lora_kernel(
...
@@ -82,6 +84,8 @@ def _fused_moe_lora_kernel(
BLOCK_SIZE_K
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
IS_PRIMARY
:
tl
.
constexpr
,
):
):
pid
=
tl
.
program_id
(
axis
=
0
)
pid
=
tl
.
program_id
(
axis
=
0
)
slice_id
=
tl
.
program_id
(
axis
=
1
)
slice_id
=
tl
.
program_id
(
axis
=
1
)
...
@@ -110,13 +114,11 @@ def _fused_moe_lora_kernel(
...
@@ -110,13 +114,11 @@ def _fused_moe_lora_kernel(
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
+
lora_id
)
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
+
lora_id
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
return
# get the expert_id to process curr shard
# get the expert_id to process curr shard
ind
=
lora_id
*
stride_el
+
pid_m
ind
=
lora_id
*
stride_el
+
pid_m
expert_id
=
tl
.
load
(
expert_ids_ptr
+
ind
,
ind
<
max_loras
*
stride_el
,
-
1
)
expert_id
=
tl
.
load
(
expert_ids_ptr
+
ind
,
ind
<
max_loras
*
stride_el
,
-
1
)
if
expert_id
==
-
1
:
if
expert_id
==
-
1
:
return
return
# get a_ptr,b_ptr,c_ptr
# get a_ptr,b_ptr,c_ptr
cur_a_ptr
=
a_ptr
+
(
slice_id
%
num_slice_a
)
*
slice_a_size
cur_a_ptr
=
a_ptr
+
(
slice_id
%
num_slice_a
)
*
slice_a_size
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
c_ptr
.
dtype
.
element_ty
))
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
c_ptr
.
dtype
.
element_ty
))
...
@@ -149,12 +151,17 @@ def _fused_moe_lora_kernel(
...
@@ -149,12 +151,17 @@ def _fused_moe_lora_kernel(
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
grid_k
):
for
k
in
range
(
0
,
grid_k
):
k_remaining
=
K
-
k
*
(
BLOCK_SIZE_K
*
SPLIT_K
)
k_remaining
=
K
-
k
*
(
BLOCK_SIZE_K
*
SPLIT_K
)
# pre-fetch lora weight
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
k_remaining
,
other
=
0.0
)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
if
USE_GDC
and
not
IS_PRIMARY
:
tl
.
extra
.
cuda
.
gdc_wait
()
a
=
tl
.
load
(
a
=
tl
.
load
(
a_ptrs
,
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
k_remaining
),
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
k_remaining
),
other
=
0.0
,
other
=
0.0
,
)
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
k_remaining
,
other
=
0.0
)
accumulator
+=
tl
.
dot
(
a
,
b
)
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
SPLIT_K
*
stride_ak
a_ptrs
+=
BLOCK_SIZE_K
*
SPLIT_K
*
stride_ak
...
@@ -163,12 +170,15 @@ def _fused_moe_lora_kernel(
...
@@ -163,12 +170,15 @@ def _fused_moe_lora_kernel(
if
MUL_ROUTED_WEIGHT
:
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
USE_GDC
and
IS_PRIMARY
:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl
.
extra
.
cuda
.
gdc_launch_dependents
()
accumulator
=
accumulator
.
to
(
c_ptr
.
dtype
.
element_ty
)
accumulator
=
accumulator
.
to
(
c_ptr
.
dtype
.
element_ty
)
# Write back the block of the output
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
cur_c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_ptrs
=
cur_c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
if
SPLIT_K
==
1
:
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
else
:
else
:
...
@@ -209,7 +219,7 @@ def _fused_moe_lora_shrink(
...
@@ -209,7 +219,7 @@ def _fused_moe_lora_shrink(
mul_routed_weight
:
bool
=
False
,
mul_routed_weight
:
bool
=
False
,
)
->
None
:
)
->
None
:
w1_lora_a_stacked
=
lora_a_stacked
[
0
]
w1_lora_a_stacked
=
lora_a_stacked
[
0
]
use_gdc
=
supports_pdl
(
qcurr_hidden_states
.
device
)
shrink_config
=
{
shrink_config
=
{
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_N"
:
block_size_n
,
...
@@ -218,6 +228,8 @@ def _fused_moe_lora_shrink(
...
@@ -218,6 +228,8 @@ def _fused_moe_lora_shrink(
"num_warps"
:
num_warps
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
"num_stages"
:
num_stages
,
"SPLIT_K"
:
split_k
,
"SPLIT_K"
:
split_k
,
"USE_GDC"
:
use_gdc
,
"launch_pdl"
:
use_gdc
,
# triton kernel metadata
}
}
b_ptr
=
_get_ptr
(
lora_a_stacked
,
device
)
b_ptr
=
_get_ptr
(
lora_a_stacked
,
device
)
...
@@ -229,7 +241,6 @@ def _fused_moe_lora_shrink(
...
@@ -229,7 +241,6 @@ def _fused_moe_lora_shrink(
len
(
lora_a_stacked
),
len
(
lora_a_stacked
),
lora_a_stacked
[
0
].
shape
[
0
],
lora_a_stacked
[
0
].
shape
[
0
],
)
)
_fused_moe_lora_kernel
[
grid
](
_fused_moe_lora_kernel
[
grid
](
qcurr_hidden_states
,
qcurr_hidden_states
,
b_ptr
,
b_ptr
,
...
@@ -261,6 +272,7 @@ def _fused_moe_lora_shrink(
...
@@ -261,6 +272,7 @@ def _fused_moe_lora_shrink(
num_slice_c
=
num_slices
,
num_slice_c
=
num_slices
,
top_k
=
1
if
mul_routed_weight
else
top_k_num
,
top_k
=
1
if
mul_routed_weight
else
top_k_num
,
MUL_ROUTED_WEIGHT
=
False
,
MUL_ROUTED_WEIGHT
=
False
,
IS_PRIMARY
=
True
,
**
shrink_config
,
**
shrink_config
,
)
)
...
@@ -314,7 +326,7 @@ def _fused_moe_lora_expand(
...
@@ -314,7 +326,7 @@ def _fused_moe_lora_expand(
dtype
=
output
.
dtype
,
dtype
=
output
.
dtype
,
device
=
device
,
device
=
device
,
)
)
use_gdc
=
supports_pdl
(
a_intermediate_cache1
.
device
)
expand_config
=
{
expand_config
=
{
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_N"
:
block_size_n
,
...
@@ -323,6 +335,8 @@ def _fused_moe_lora_expand(
...
@@ -323,6 +335,8 @@ def _fused_moe_lora_expand(
"num_warps"
:
num_warps
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
"num_stages"
:
num_stages
,
"SPLIT_K"
:
split_k
,
# Set split_k = 1 for expand calls
"SPLIT_K"
:
split_k
,
# Set split_k = 1 for expand calls
"USE_GDC"
:
use_gdc
,
"launch_pdl"
:
use_gdc
,
# triton kernel metadata
}
}
grid
=
lambda
META
:
(
grid
=
lambda
META
:
(
...
@@ -361,6 +375,7 @@ def _fused_moe_lora_expand(
...
@@ -361,6 +375,7 @@ def _fused_moe_lora_expand(
num_slice_c
=
num_slices
,
num_slice_c
=
num_slices
,
top_k
=
1
,
top_k
=
1
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
IS_PRIMARY
=
False
,
**
expand_config
,
**
expand_config
,
)
)
for
i
in
range
(
num_slices
):
for
i
in
range
(
num_slices
):
...
...
vllm/lora/ops/triton_ops/kernel_utils.py
View file @
21b82f4e
...
@@ -22,6 +22,7 @@ def mm_k(
...
@@ -22,6 +22,7 @@ def mm_k(
SPLIT_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
b_dtype
:
tl
.
constexpr
,
b_dtype
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
):
):
"""
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
...
@@ -45,19 +46,25 @@ def mm_k(
...
@@ -45,19 +46,25 @@ def mm_k(
CAST_TYPE: if True, cast the values from the A matrix to the B
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
matrix dtype.
b_dtype: datatype of the B matrix
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
"""
"""
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)):
for
k
in
range
(
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)):
if
EVEN_K
:
if
EVEN_K
:
tiled_a
=
tl
.
load
(
a_ptr
)
# pre-fetch lora weight
tiled_b
=
tl
.
load
(
b_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
if
USE_GDC
:
tl
.
extra
.
cuda
.
gdc_wait
()
tiled_a
=
tl
.
load
(
a_ptr
)
else
:
else
:
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
)
tiled_b
=
tl
.
load
(
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
offset_k
[:,
None
]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
b_ptr
,
mask
=
offset_k
[:,
None
]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
)
)
if
USE_GDC
:
tl
.
extra
.
cuda
.
gdc_wait
()
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
)
if
CAST_TYPE
:
if
CAST_TYPE
:
tiled_a
=
tiled_a
.
to
(
b_dtype
)
tiled_a
=
tiled_a
.
to
(
b_dtype
)
accumulator
+=
tl
.
dot
(
accumulator
+=
tl
.
dot
(
...
@@ -102,6 +109,7 @@ def do_expand_kernel(
...
@@ -102,6 +109,7 @@ def do_expand_kernel(
EVEN_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
):
):
"""
"""
Given an array of integers that identifies the rows of A, ram,
Given an array of integers that identifies the rows of A, ram,
...
@@ -154,6 +162,7 @@ def do_expand_kernel(
...
@@ -154,6 +162,7 @@ def do_expand_kernel(
# Compute the block matrix product.
# Compute the block matrix product.
SPLIT_K
=
1
SPLIT_K
=
1
accumulator
=
mm_k
(
accumulator
=
mm_k
(
a_ptr
,
a_ptr
,
b_ptr
,
b_ptr
,
...
@@ -168,6 +177,7 @@ def do_expand_kernel(
...
@@ -168,6 +177,7 @@ def do_expand_kernel(
SPLIT_K
,
SPLIT_K
,
CAST_TYPE
,
CAST_TYPE
,
cur_lora_ptr
.
dtype
.
element_ty
,
cur_lora_ptr
.
dtype
.
element_ty
,
USE_GDC
,
)
)
tiled_c
=
accumulator
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
tiled_c
=
accumulator
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
...
@@ -223,6 +233,7 @@ def do_shrink_kernel(
...
@@ -223,6 +233,7 @@ def do_shrink_kernel(
EVEN_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
):
):
"""
"""
Given an array of integers that identifies the rows of A, ram,
Given an array of integers that identifies the rows of A, ram,
...
@@ -272,8 +283,11 @@ def do_shrink_kernel(
...
@@ -272,8 +283,11 @@ def do_shrink_kernel(
SPLIT_K
,
SPLIT_K
,
False
,
False
,
cur_lora_ptr
.
dtype
.
element_ty
,
cur_lora_ptr
.
dtype
.
element_ty
,
False
,
# USE_GDC is always False in shrink kernel
)
)
# GDC launch dependents hints the runtime system to launch dependent kernels.
if
USE_GDC
:
tl
.
extra
.
cuda
.
gdc_launch_dependents
()
# Identify the C output pointers to store the results of the accumulator.
# Identify the C output pointers to store the results of the accumulator.
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_cm
=
tl
.
arange
(
0
,
BLOCK_M
)
offset_cm
=
tl
.
arange
(
0
,
BLOCK_M
)
...
@@ -284,10 +298,10 @@ def do_shrink_kernel(
...
@@ -284,10 +298,10 @@ def do_shrink_kernel(
+
offset_cn
[
None
,
:]
*
output_d2_stride
+
offset_cn
[
None
,
:]
*
output_d2_stride
)
)
c_mask
=
(
offset_cm
[:,
None
]
<
M_LEN
)
&
(
offset_cn
[
None
,
:]
<
N
)
c_mask
=
(
offset_cm
[:,
None
]
<
M_LEN
)
&
(
offset_cn
[
None
,
:]
<
N
)
accumulator
*=
scaling
accumulator
*=
scaling
# handles write-back with reduction-splitting
# handles write-back with reduction-splitting
if
SPLIT_K
==
1
:
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
tl
.
store
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
else
:
else
:
tl
.
atomic_add
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
tl
.
atomic_add
(
c_ptr
,
accumulator
,
mask
=
c_mask
,
sem
=
"relaxed"
)
vllm/lora/ops/triton_ops/lora_expand_op.py
View file @
21b82f4e
...
@@ -14,6 +14,8 @@ from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
...
@@ -14,6 +14,8 @@ from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
.utils
import
supports_pdl
@
triton
.
jit
@
triton
.
jit
def
_lora_expand_kernel
(
def
_lora_expand_kernel
(
...
@@ -45,6 +47,7 @@ def _lora_expand_kernel(
...
@@ -45,6 +47,7 @@ def _lora_expand_kernel(
CAST_TYPE
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
SAME_STRIDE
:
tl
.
constexpr
,
SAME_STRIDE
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
):
):
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
cta_m_num
=
tl
.
cdiv
(
M
,
BLOCK_M
)
cta_m_num
=
tl
.
cdiv
(
M
,
BLOCK_M
)
...
@@ -121,6 +124,7 @@ def _lora_expand_kernel(
...
@@ -121,6 +124,7 @@ def _lora_expand_kernel(
EVEN_K
,
EVEN_K
,
CAST_TYPE
,
CAST_TYPE
,
ADD_INPUTS
,
ADD_INPUTS
,
USE_GDC
,
)
)
...
@@ -236,7 +240,7 @@ def _lora_expand(
...
@@ -236,7 +240,7 @@ def _lora_expand(
# thread blocks simply exit.
# thread blocks simply exit.
MAX_LORAS
,
MAX_LORAS
,
)
)
use_gdc
=
supports_pdl
(
inputs
.
device
)
_lora_expand_kernel
[
grid
](
_lora_expand_kernel
[
grid
](
inputs
,
inputs
,
lora_ptr_tensor
,
lora_ptr_tensor
,
...
@@ -266,9 +270,11 @@ def _lora_expand(
...
@@ -266,9 +270,11 @@ def _lora_expand(
CAST_TYPE
,
CAST_TYPE
,
NUM_SLICES
,
NUM_SLICES
,
same_stride
,
same_stride
,
use_gdc
,
num_warps
=
NUM_WARPS
,
num_warps
=
NUM_WARPS
,
num_ctas
=
NUM_CTAS
,
num_ctas
=
NUM_CTAS
,
num_stages
=
NUM_STAGES
,
num_stages
=
NUM_STAGES
,
launch_pdl
=
use_gdc
,
)
)
return
return
...
...
vllm/lora/ops/triton_ops/lora_shrink_op.py
View file @
21b82f4e
...
@@ -14,6 +14,8 @@ from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
...
@@ -14,6 +14,8 @@ from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
.utils
import
supports_pdl
@
triton
.
jit
@
triton
.
jit
def
_lora_shrink_kernel
(
def
_lora_shrink_kernel
(
...
@@ -43,6 +45,7 @@ def _lora_shrink_kernel(
...
@@ -43,6 +45,7 @@ def _lora_shrink_kernel(
SPLIT_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
):
):
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
cta_m_num
=
tl
.
cdiv
(
M
,
BLOCK_M
)
cta_m_num
=
tl
.
cdiv
(
M
,
BLOCK_M
)
...
@@ -83,7 +86,6 @@ def _lora_shrink_kernel(
...
@@ -83,7 +86,6 @@ def _lora_shrink_kernel(
cta_lora_seq_indices
=
(
cta_lora_seq_indices
=
(
token_indices_sorted_by_lora_ids
+
lora_m_indices_start
+
cta_m_offset
token_indices_sorted_by_lora_ids
+
lora_m_indices_start
+
cta_m_offset
)
)
# Load all relevant row indices.
# Load all relevant row indices.
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
%
cta_m_len
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
%
cta_m_len
ram
=
tl
.
load
(
cta_lora_seq_indices
+
offset_m
)
ram
=
tl
.
load
(
cta_lora_seq_indices
+
offset_m
)
...
@@ -118,6 +120,7 @@ def _lora_shrink_kernel(
...
@@ -118,6 +120,7 @@ def _lora_shrink_kernel(
EVEN_K
,
EVEN_K
,
SPLIT_K
,
SPLIT_K
,
SLICE_NUM
,
SLICE_NUM
,
USE_GDC
,
)
)
...
@@ -217,7 +220,7 @@ def _lora_shrink(
...
@@ -217,7 +220,7 @@ def _lora_shrink(
# thread blocks exit early.
# thread blocks exit early.
MAX_LORAS
,
MAX_LORAS
,
)
)
use_gdc
=
supports_pdl
(
inputs
.
device
)
_lora_shrink_kernel
[
grid
](
_lora_shrink_kernel
[
grid
](
inputs
,
inputs
,
lora_ptr_tensor
,
lora_ptr_tensor
,
...
@@ -245,9 +248,11 @@ def _lora_shrink(
...
@@ -245,9 +248,11 @@ def _lora_shrink(
SPLIT_K
,
SPLIT_K
,
GROUP_SIZE_M
,
GROUP_SIZE_M
,
NUM_SLICES
,
NUM_SLICES
,
use_gdc
,
num_warps
=
NUM_WARPS
,
num_warps
=
NUM_WARPS
,
num_ctas
=
NUM_CTAS
,
num_ctas
=
NUM_CTAS
,
num_stages
=
NUM_STAGES
,
num_stages
=
NUM_STAGES
,
launch_pdl
=
use_gdc
,
)
)
return
return
...
...
vllm/lora/ops/triton_ops/utils.py
View file @
21b82f4e
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
functools
import
functools
import
json
import
json
from
functools
import
lru_cache
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
from
typing
import
Any
...
@@ -10,6 +11,7 @@ import torch
...
@@ -10,6 +11,7 @@ import torch
from
vllm
import
envs
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -282,3 +284,12 @@ def get_lora_op_configs(
...
@@ -282,3 +284,12 @@ def get_lora_op_configs(
assert
config_data
is
not
None
assert
config_data
is
not
None
return
config_data
return
config_data
@lru_cache
def supports_pdl(device: torch.device | None = None) -> bool:
    """Report whether Programmatic Dependent Launch (PDL) can be enabled.

    Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py

    Args:
        device: Not read by the body; it only forms part of the
            ``lru_cache`` key, so the platform check runs once per distinct
            value passed in.

    Returns:
        The result of checking that the current platform is CUDA and that it
        reports device capability 90 or higher.
    """
    # PDL requires compute capability SM90 or above.
    min_capability = 90
    platform = current_platform
    return platform.is_cuda() and platform.has_device_capability(min_capability)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment