Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3ce711b
Unverified
Commit
b3ce711b
authored
Mar 13, 2026
by
yugong333
Committed by
GitHub
Mar 13, 2026
Browse files
Fp8 lora dense kernel (#35242)
Signed-off-by:
Yu Gong
<
yu3.gong@gmail.com
>
parent
abf61aaa
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
2439 additions
and
1 deletion
+2439
-1
tests/lora/test_punica_ops_fp8.py
tests/lora/test_punica_ops_fp8.py
+999
-0
vllm/lora/ops/triton_ops/__init__.py
vllm/lora/ops/triton_ops/__init__.py
+4
-0
vllm/lora/ops/triton_ops/fp8_kernel_utils.py
vllm/lora/ops/triton_ops/fp8_kernel_utils.py
+603
-0
vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
+403
-0
vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py
vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py
+429
-0
vllm/lora/ops/triton_ops/utils.py
vllm/lora/ops/triton_ops/utils.py
+1
-1
No files found.
tests/lora/test_punica_ops_fp8.py
0 → 100644
View file @
b3ce711b
This diff is collapsed.
Click to expand it.
vllm/lora/ops/triton_ops/__init__.py
View file @
b3ce711b
...
@@ -12,13 +12,17 @@ from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
...
@@ -12,13 +12,17 @@ from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
fused_moe_lora_expand
,
fused_moe_lora_expand
,
fused_moe_lora_shrink
,
fused_moe_lora_shrink
,
)
)
from
vllm.lora.ops.triton_ops.lora_expand_fp8_op
import
lora_expand_fp8
from
vllm.lora.ops.triton_ops.lora_expand_op
import
lora_expand
from
vllm.lora.ops.triton_ops.lora_expand_op
import
lora_expand
from
vllm.lora.ops.triton_ops.lora_kernel_metadata
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops.lora_kernel_metadata
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops.lora_shrink_fp8_op
import
lora_shrink_fp8
from
vllm.lora.ops.triton_ops.lora_shrink_op
import
lora_shrink
from
vllm.lora.ops.triton_ops.lora_shrink_op
import
lora_shrink
__all__
=
[
__all__
=
[
"lora_expand"
,
"lora_expand"
,
"lora_expand_fp8"
,
"lora_shrink"
,
"lora_shrink"
,
"lora_shrink_fp8"
,
"LoRAKernelMeta"
,
"LoRAKernelMeta"
,
"fused_moe_lora"
,
"fused_moe_lora"
,
"fused_moe_lora_shrink"
,
"fused_moe_lora_shrink"
,
...
...
vllm/lora/ops/triton_ops/fp8_kernel_utils.py
0 → 100644
View file @
b3ce711b
This diff is collapsed.
Click to expand it.
vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
0 → 100644
View file @
b3ce711b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import
torch
from
vllm.lora.ops.triton_ops.fp8_kernel_utils
import
do_expand_kernel_fp8
from
vllm.lora.ops.triton_ops.utils
import
(
_get_lora_b_ptr
,
get_lora_op_configs
,
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
_EXPAND_LORA_SCALE_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
def
_get_expand_lora_scale_ptr
(
lora_weights
:
list
[
torch
.
Tensor
],
device
:
torch
.
device
):
"""
`_EXPAND_LORA_SCALE_PTR_DICT` collects the required information during
`profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key
=
tuple
(
lora_weight
.
data_ptr
()
for
lora_weight
in
lora_weights
)
if
(
ptr_tensor
:
=
_EXPAND_LORA_SCALE_PTR_DICT
.
get
(
key
))
is
not
None
:
return
ptr_tensor
if
len
(
lora_weights
)
>
1
:
tensor_ptrs
=
[]
for
lora_weight
in
lora_weights
:
tensor_ptrs
.
append
(
lora_weight
.
data_ptr
())
ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
,
dtype
=
torch
.
uint64
)
else
:
# Single slice: return the actual tensor so the kernel can use it
# directly without pointer indirection (matches SLICE_NUM == 1 path).
ptr_tensor
=
lora_weights
[
0
]
_EXPAND_LORA_SCALE_PTR_DICT
[
key
]
=
ptr_tensor
return
_EXPAND_LORA_SCALE_PTR_DICT
.
get
(
key
)
@
triton
.
jit
def
_lora_expand_kernel_fp8
(
input_ptr
,
lora_ptr
,
out_ptr
,
a_scale_ptr
,
b_scale_ptr
,
M
,
N
,
K
,
token_indices_sorted_by_lora_ids
,
num_tokens_per_lora
,
lora_token_start_loc
,
lora_ids
,
slice_start_loc
,
input_d0_stride
,
input_d1_stride
,
input_d2_stride
,
ls_d0_ptr
,
ls_d1_ptr
,
ls_d2_ptr
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_d0_stride
,
output_d1_stride
,
output_hs_ptr
,
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
SAME_STRIDE
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
launch_pdl
:
tl
.
constexpr
,
):
"""
FP8-compatible expand kernel wrapper.
"""
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
cta_m_num
=
tl
.
cdiv
(
M
,
BLOCK_M
)
pid_mn
=
tl
.
program_id
(
axis
=
0
)
pid_m
=
pid_mn
%
cta_m_num
pid_n
=
(
pid_mn
//
cta_m_num
)
%
cta_n_num
slice_id
=
tl
.
program_id
(
axis
=
1
)
lora_idx
=
tl
.
program_id
(
axis
=
2
)
lora_id
=
tl
.
load
(
lora_ids
+
lora_idx
)
if
lora_id
==
-
1
:
return
lora_m_size
=
tl
.
load
(
num_tokens_per_lora
+
lora_idx
)
cta_m_offset
=
pid_m
*
BLOCK_M
if
cta_m_offset
>=
lora_m_size
:
return
curr_N
=
N
if
SAME_STRIDE
else
tl
.
load
(
output_hs_ptr
+
slice_id
)
if
pid_n
*
BLOCK_N
>=
curr_N
:
return
cta_m_len
=
min
(
BLOCK_M
,
lora_m_size
-
cta_m_offset
)
lora_m_indices_start
=
tl
.
load
(
lora_token_start_loc
+
lora_idx
)
cta_lora_seq_indices
=
(
token_indices_sorted_by_lora_ids
+
lora_m_indices_start
+
cta_m_offset
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
%
cta_m_len
ram
=
tl
.
load
(
cta_lora_seq_indices
+
offset_m
)
do_expand_kernel_fp8
(
pid_n
,
lora_id
,
slice_id
,
input_ptr
,
lora_ptr
,
out_ptr
,
a_scale_ptr
,
b_scale_ptr
,
curr_N
,
K
,
cta_m_len
,
ram
,
slice_start_loc
,
input_d0_stride
,
input_d1_stride
,
input_d2_stride
,
ls_d0_ptr
,
ls_d1_ptr
,
ls_d2_ptr
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_d0_stride
,
output_d1_stride
,
group_n
,
group_k
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
SAME_STRIDE
,
SLICE_NUM
,
EVEN_K
,
CAST_TYPE
,
ADD_INPUTS
,
USE_GDC
,
use_fp8_w8a8
,
per_channel_quant
,
)
@
torch
.
inference_mode
()
def
_lora_expand_fp8
(
inputs
:
torch
.
Tensor
,
# shape [num_slices, num_tokens, lora_rank]
lora_b_weights
:
list
[
torch
.
Tensor
],
# FP8 [num_lora, hidden_size, lora_rank]
output_tensor
:
torch
.
Tensor
,
# shape [num_tokens, hidden_size * num_slices]
token_lora_mapping
:
torch
.
Tensor
,
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
num_tokens_per_lora
:
torch
.
Tensor
,
lora_token_start_loc
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
no_lora_flag_cpu
:
torch
.
Tensor
,
# shape [1]
num_active_loras
:
int
,
# number of active LoRAs (unused here, for API compat)
b_scale
:
list
[
torch
.
Tensor
],
# LoRA B weight scale per slice
a_scale
:
torch
.
Tensor
|
None
=
None
,
# Scale for shrink output (optional)
offset_start
:
int
=
0
,
add_inputs
:
bool
=
False
,
group_k
:
int
=
0
,
group_n
:
int
=
0
,
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
"""
FP8-compatible LoRA expand operation.
Args:
inputs: Input tensor from shrink operation [num_slices, num_tokens, lora_rank]
lora_b_weights: List of FP8 LoRA B weights per slice
output_tensor: Output tensor
a_scale: Optional scale for input (if input is quantized)
b_scale: Weight quantization scales per slice
token_lora_mapping: Token to LoRA ID mapping
token_indices_sorted_by_lora_ids: Sorted token indices
num_tokens_per_lora: Number of tokens per LoRA
lora_token_start_loc: Start location for each LoRA's tokens
lora_ids: LoRA IDs to process
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
group_k (int, optional): Block size for K in block-wise quantization.
group_n (int, optional): Block size for N in block-wise quantization.
use_fp8_w8a8 (bool, optional): Whether to use FP8 W8A8 quantization.
per_channel_quant (bool, optional): Whether to use per-channel quantization.
"""
assert
no_lora_flag_cpu
.
numel
()
==
1
if
no_lora_flag_cpu
.
item
():
# None of the inputs require LoRA.
return
if
use_fp8_w8a8
:
assert
inputs
.
dtype
in
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
,
]
for
weight
in
lora_b_weights
:
assert
weight
.
dtype
in
[
torch
.
float8_e5m2
,
torch
.
float8_e4m3fn
,
]
else
:
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
for
weight
in
lora_b_weights
:
assert
weight
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
inputs
.
size
(
0
)
==
len
(
lora_b_weights
)
assert
output_tensor
.
is_contiguous
()
# metadata sanity check.
M
=
inputs
.
size
(
1
)
assert
token_lora_mapping
.
size
(
0
)
==
M
assert
token_lora_mapping
.
size
(
0
)
==
token_indices_sorted_by_lora_ids
.
size
(
0
)
assert
lora_ids
.
size
(
0
)
==
num_tokens_per_lora
.
size
(
0
)
assert
lora_token_start_loc
.
size
(
0
)
==
lora_ids
.
size
(
0
)
+
1
(
slice_start_tensor
,
lora_ptr_tensor
,
lora_strides_d0_tensor
,
lora_strides_d1_tensor
,
lora_strides_d2_tensor
,
hidden_sizes_tensor
,
same_stride
,
MAX_N
,
)
=
_get_lora_b_ptr
(
lora_b_weights
,
offset_start
,
inputs
.
device
)
# Get scale pointers
if
b_scale
is
not
None
:
b_scale_ptr_tensor
=
_get_expand_lora_scale_ptr
(
b_scale
,
inputs
.
device
)
else
:
b_scale_ptr_tensor
=
None
K
=
lora_b_weights
[
0
].
shape
[
-
1
]
ADD_INPUTS
=
add_inputs
MAX_LORAS
=
lora_ids
.
size
(
0
)
CAST_TYPE
=
False
NUM_SLICES
=
len
(
lora_b_weights
)
# Triton kernel configs.
kernel_config
=
get_lora_op_configs
(
op_type
=
"expand"
,
max_loras
=
MAX_LORAS
,
batch
=
M
,
hidden_size
=
MAX_N
,
rank
=
K
,
num_slices
=
NUM_SLICES
,
add_inputs
=
add_inputs
,
)
BLOCK_M
=
kernel_config
[
"block_m"
]
BLOCK_N
=
kernel_config
[
"block_n"
]
BLOCK_K
=
kernel_config
[
"block_k"
]
NUM_WARPS
=
kernel_config
[
"num_warps"
]
NUM_CTAS
=
kernel_config
.
get
(
"num_ctas"
,
1
)
NUM_STAGES
=
kernel_config
[
"num_stages"
]
EVEN_K
=
K
%
BLOCK_K
==
0
grid
=
(
triton
.
cdiv
(
M
,
BLOCK_M
)
*
triton
.
cdiv
(
MAX_N
,
BLOCK_N
),
NUM_SLICES
,
num_active_loras
,
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
use_gdc
=
False
# supports_pdl(inputs.device)
# Get scale strides
if
a_scale
is
not
None
:
a_scale_m_stride
=
a_scale
.
stride
(
0
)
if
a_scale
.
dim
()
>
1
else
0
a_scale_k_stride
=
a_scale
.
stride
(
-
1
)
if
a_scale
.
dim
()
>
1
else
0
else
:
a_scale_m_stride
=
0
a_scale_k_stride
=
0
if
b_scale
is
not
None
and
b_scale
[
0
].
dim
()
>
0
:
b_scale_l_stride
=
b_scale
[
0
].
stride
(
0
)
if
b_scale
[
0
].
dim
()
>
0
else
0
b_scale_n_stride
=
(
b_scale
[
0
].
stride
(
-
2
)
if
b_scale
[
0
].
dim
()
>
2
else
(
b_scale
[
0
].
stride
(
-
1
)
if
b_scale
[
0
].
dim
()
>
1
else
1
)
)
b_scale_k_stride
=
b_scale
[
0
].
stride
(
-
1
)
if
b_scale
[
0
].
dim
()
>
2
else
0
else
:
b_scale_l_stride
=
1
b_scale_n_stride
=
0
b_scale_k_stride
=
0
_lora_expand_kernel_fp8
[
grid
](
inputs
,
lora_ptr_tensor
,
output_tensor
,
a_scale
,
b_scale_ptr_tensor
,
M
,
MAX_N
,
K
,
token_indices_sorted_by_lora_ids
,
num_tokens_per_lora
,
lora_token_start_loc
,
lora_ids
,
slice_start_tensor
,
inputs
.
stride
(
0
),
inputs
.
stride
(
1
),
inputs
.
stride
(
2
),
lora_strides_d0_tensor
,
lora_strides_d1_tensor
,
lora_strides_d2_tensor
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
hidden_sizes_tensor
,
group_n
,
group_k
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
ADD_INPUTS
,
CAST_TYPE
,
NUM_SLICES
,
same_stride
,
use_gdc
,
use_fp8_w8a8
=
use_fp8_w8a8
,
per_channel_quant
=
per_channel_quant
,
num_warps
=
NUM_WARPS
,
num_ctas
=
NUM_CTAS
,
num_stages
=
NUM_STAGES
,
launch_pdl
=
use_gdc
,
)
return
def
_lora_expand_fp8_fake
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
list
[
torch
.
Tensor
],
output_tensor
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
num_tokens_per_lora
:
torch
.
Tensor
,
lora_token_start_loc
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
no_lora_flag_cpu
:
torch
.
Tensor
,
num_active_loras
:
int
,
b_scale
:
list
[
torch
.
Tensor
],
a_scale
:
torch
.
Tensor
|
None
=
None
,
offset_start
:
int
=
0
,
add_inputs
:
bool
=
False
,
group_k
:
int
=
0
,
group_n
:
int
=
0
,
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
return
try
:
direct_register_custom_op
(
op_name
=
"lora_expand_fp8"
,
op_func
=
_lora_expand_fp8
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
_lora_expand_fp8_fake
,
)
lora_expand_fp8
=
torch
.
ops
.
vllm
.
lora_expand_fp8
except
AttributeError
:
lora_expand_fp8
=
_lora_expand_fp8
vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py
0 → 100644
View file @
b3ce711b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import
torch
from
vllm.lora.ops.triton_ops.fp8_kernel_utils
import
do_shrink_kernel_fp8
from
vllm.lora.ops.triton_ops.utils
import
_get_lora_a_ptr
,
get_lora_op_configs
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
_SHRINK_LORA_SCALE_PTR_DICT
:
dict
[
tuple
[
int
,
...],
tuple
]
=
{}
def
_get_shrink_lora_scale_ptr
(
lora_scale_weights
:
list
[
torch
.
Tensor
],
device
:
torch
.
device
):
"""
`_SHRINK_LORA_SCALE_PTR_DICT` collects the required information during
`profile_run`. After this, it remains constant and subsequent usage is
through LUT.
Returns a tuple of (scale_ptr_tensor, l_stride, n_stride, k_stride).
Supports scale tensors of varying dimensionality:
- 1D: (lora_num,) — tensor-wise quantization
- 2D: (lora_num, N) — per-channel quantization
- 3D: (lora_num, N, K) — block-wise quantization
- 4D: (lora_num, 1, N, K) — block-wise with extra dim (squeezed to 3D)
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key
=
tuple
(
lora_weight
.
data_ptr
()
for
lora_weight
in
lora_scale_weights
)
if
values
:
=
_SHRINK_LORA_SCALE_PTR_DICT
.
get
(
key
):
return
values
tensor_ptrs
=
[]
scale_l_strides
=
[]
scale_n_strides
=
[]
scale_k_strides
=
[]
for
lora_scale_weight
in
lora_scale_weights
:
if
lora_scale_weight
.
ndim
==
4
:
# shape:(lora_num,1,size,rank)
assert
lora_scale_weight
.
size
(
1
)
==
1
lora_scale_weight
=
lora_scale_weight
.
squeeze
(
dim
=
1
)
assert
1
<=
lora_scale_weight
.
ndim
<=
3
assert
lora_scale_weight
.
is_contiguous
()
tensor_ptrs
.
append
(
lora_scale_weight
.
data_ptr
())
scale_l_strides
.
append
(
lora_scale_weight
.
stride
(
0
)
if
lora_scale_weight
.
ndim
>
0
else
0
)
scale_n_strides
.
append
(
lora_scale_weight
.
stride
(
-
2
)
if
lora_scale_weight
.
ndim
>
2
else
(
lora_scale_weight
.
stride
(
-
1
)
if
lora_scale_weight
.
ndim
>
1
else
1
)
)
scale_k_strides
.
append
(
lora_scale_weight
.
stride
(
-
1
)
if
lora_scale_weight
.
ndim
>
2
else
0
)
if
len
(
lora_scale_weights
)
>
1
:
scale_ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
,
dtype
=
torch
.
uint64
)
else
:
scale_ptr_tensor
=
lora_scale_weights
[
0
]
if
(
len
(
set
(
scale_l_strides
))
>
1
or
len
(
set
(
scale_n_strides
))
>
1
or
len
(
set
(
scale_k_strides
))
>
1
):
raise
ValueError
(
"All LoRA scale weights must have the same stride."
)
_SHRINK_LORA_SCALE_PTR_DICT
[
key
]
=
(
scale_ptr_tensor
,
scale_l_strides
[
0
],
scale_n_strides
[
0
],
scale_k_strides
[
0
],
)
return
_SHRINK_LORA_SCALE_PTR_DICT
.
get
(
key
)
@
triton
.
jit
def
_lora_shrink_kernel_fp8
(
input_ptr
,
lora_ptr
,
out_ptr
,
a_scale_ptr
,
b_scale_ptr
,
M
,
N
,
K
,
token_indices_sorted_by_lora_ids
,
num_tokens_per_lora
,
lora_token_start_loc
,
lora_ids
,
scaling
,
input_d0_stride
,
input_d1_stride
,
lora_d0_stride
,
lora_d1_stride
,
lora_d2_stride
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_d0_stride
,
output_d1_stride
,
output_d2_stride
,
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
USE_GDC
:
tl
.
constexpr
,
## should always be false in shrink kernel
use_fp8_w8a8
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
launch_pdl
:
tl
.
constexpr
,
):
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
cta_m_num
=
tl
.
cdiv
(
M
,
BLOCK_M
)
pid_sk_m_n
=
tl
.
program_id
(
axis
=
0
)
pid_sk
=
pid_sk_m_n
%
SPLIT_K
pid_m_n
=
pid_sk_m_n
//
SPLIT_K
num_pid_in_group
=
GROUP_SIZE_M
*
cta_n_num
group_id
=
pid_m_n
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
cta_m_num
-
first_pid_m
,
GROUP_SIZE_M
)
# Column-major ordering within groups for better cache reuse
pid_m
=
first_pid_m
+
((
pid_m_n
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid_m_n
%
num_pid_in_group
)
//
group_size_m
slice_id
=
tl
.
program_id
(
axis
=
1
)
lora_idx
=
tl
.
program_id
(
axis
=
2
)
lora_id
=
tl
.
load
(
lora_ids
+
lora_idx
)
if
lora_id
==
-
1
:
# Early exit for the no-lora case.
return
lora_m_size
=
tl
.
load
(
num_tokens_per_lora
+
lora_idx
)
cta_m_offset
=
pid_m
*
BLOCK_M
if
cta_m_offset
>=
lora_m_size
:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len
=
min
(
BLOCK_M
,
lora_m_size
-
cta_m_offset
)
# Identify all rows that this CTA should process.
lora_m_indices_start
=
tl
.
load
(
lora_token_start_loc
+
lora_idx
)
cta_lora_seq_indices
=
(
token_indices_sorted_by_lora_ids
+
lora_m_indices_start
+
cta_m_offset
)
# Load all relevant row indices.
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
%
cta_m_len
ram
=
tl
.
load
(
cta_lora_seq_indices
+
offset_m
)
do_shrink_kernel_fp8
(
pid_n
,
pid_sk
,
slice_id
,
lora_id
,
input_ptr
,
lora_ptr
,
out_ptr
,
a_scale_ptr
,
b_scale_ptr
,
N
,
K
,
cta_m_len
,
ram
,
# array identifying the rows of Input ptr to operate on
# input strides
input_d0_stride
,
input_d1_stride
,
# lora strides
lora_d0_stride
,
lora_d1_stride
,
lora_d2_stride
,
# scale strides
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
# output strides
output_d0_stride
,
output_d1_stride
,
output_d2_stride
,
scaling
,
# block size for block-wise quantization
group_n
,
group_k
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
SLICE_NUM
,
USE_GDC
,
use_fp8_w8a8
,
per_channel_quant
,
launch_pdl
,
)
@
torch
.
inference_mode
()
def
_lora_shrink_fp8
(
inputs
:
torch
.
Tensor
,
# shape [num_tokens, hidden_size] - FP8 or FP16/BF16
lora_a_weights
:
list
[
torch
.
Tensor
],
# shape [num_loras, lora_rank, hidden_size] - FP8 or FP16/BF16
output_tensor
:
torch
.
Tensor
,
# shape [num_slices, num_tokens, lora_rank]
token_lora_mapping
:
torch
.
Tensor
,
# shape [num_tokens]
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
# shape [num_tokens]
num_tokens_per_lora
:
torch
.
Tensor
,
# shape [max-loras + 1]
lora_token_start_loc
:
torch
.
Tensor
,
# shape [max-loras + 2]
lora_ids
:
torch
.
Tensor
,
# shape [max-loras + 1]
no_lora_flag_cpu
:
torch
.
Tensor
,
# shape [1]
num_active_loras
:
int
,
# number of active LoRAs (unused here, for API compat)
scaling
:
float
,
b_scale
:
list
[
torch
.
Tensor
],
# LoRA weight scale per slice
a_scale
:
torch
.
Tensor
|
None
=
None
,
# Activation scale - per-token or block-wise
group_k
:
int
=
0
,
# Block size for K in block-wise quantization (0 = tensor-wise)
group_n
:
int
=
0
,
# Block size for N in block-wise quantization
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
"""
Args:
inputs: FP8 or FP16/BF16 input tensor [num_tokens, hidden_size]
lora_a_weights: List of FP8 or FP16/BF16 LoRA A weights per slice
output_tensor: Output tensor (FP16/BF16/FP32)
token_lora_mapping: Token to LoRA ID mapping
token_indices_sorted_by_lora_ids: Sorted token indices
num_tokens_per_lora: Number of tokens per LoRA
lora_token_start_loc: Start location for each LoRA's tokens
lora_ids: LoRA IDs to process
scaling: LoRA scaling factor
a_scale: Activation quantization scales
b_scale: Weight quantization scales per slice
group_k: Block size for K dimension quantization
group_n: Block size for N dimension quantization
use_fp8_w8a8: Whether to use FP8 weights and activations
per_channel_quant: Whether to use per-channel quantization
"""
assert
no_lora_flag_cpu
.
numel
()
==
1
if
no_lora_flag_cpu
.
item
():
# None of the inputs require LoRA.
return
assert
inputs
.
size
(
1
)
==
lora_a_weights
[
0
].
size
(
-
1
)
assert
inputs
.
is_contiguous
()
assert
output_tensor
.
is_contiguous
()
# metadata sanity check
M
=
inputs
.
size
(
0
)
assert
token_lora_mapping
.
size
(
0
)
==
M
assert
token_lora_mapping
.
size
(
0
)
==
token_indices_sorted_by_lora_ids
.
size
(
0
)
assert
lora_ids
.
size
(
0
)
==
num_tokens_per_lora
.
size
(
0
)
assert
lora_token_start_loc
.
size
(
0
)
==
lora_ids
.
size
(
0
)
+
1
output_tensor
.
zero_
()
# Get LoRA weight pointers
(
lora_ptr_tensor
,
lora_strides_d0
,
lora_strides_d1
,
lora_strides_d2
)
=
(
_get_lora_a_ptr
(
lora_a_weights
,
inputs
.
device
)
)
# Get scale pointers if using FP8
if
use_fp8_w8a8
:
assert
a_scale
is
not
None
,
"a_scale required for FP8 w8a8"
assert
b_scale
is
not
None
,
"b_scale required for FP8"
b_scale_ptr_tensor
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
=
(
_get_shrink_lora_scale_ptr
(
b_scale
,
inputs
.
device
)
)
a_scale_ptr
=
(
a_scale
if
a_scale
is
not
None
else
torch
.
tensor
(
1.0
,
device
=
inputs
.
device
)
)
else
:
b_scale_ptr_tensor
=
torch
.
tensor
(
0
,
device
=
inputs
.
device
)
b_scale_l_stride
=
0
b_scale_n_stride
=
0
b_scale_k_stride
=
0
a_scale_ptr
=
torch
.
tensor
(
0
,
device
=
inputs
.
device
)
N
,
K
=
lora_a_weights
[
0
].
shape
[
-
2
:]
# K=hidden_size, N=rank
NUM_SLICES
=
len
(
lora_a_weights
)
MAX_LORAS
=
lora_ids
.
size
(
0
)
# Triton kernel configs
kernel_config
=
get_lora_op_configs
(
"shrink"
,
max_loras
=
MAX_LORAS
,
batch
=
M
,
hidden_size
=
K
,
rank
=
N
,
num_slices
=
NUM_SLICES
,
)
BLOCK_M
=
kernel_config
[
"block_m"
]
BLOCK_N
=
kernel_config
[
"block_n"
]
BLOCK_K
=
kernel_config
[
"block_k"
]
SPLIT_K
=
kernel_config
[
"split_k"
]
NUM_WARPS
=
kernel_config
[
"num_warps"
]
NUM_STAGES
=
kernel_config
[
"num_stages"
]
NUM_CTAS
=
kernel_config
[
"num_ctas"
]
GROUP_SIZE_M
=
kernel_config
.
get
(
"group_size_m"
,
8
)
assert
BLOCK_K
is
not
None
and
SPLIT_K
is
not
None
EVEN_K
=
K
%
(
BLOCK_K
*
SPLIT_K
)
==
0
# Grid configuration with column-major ordering support
grid
=
(
SPLIT_K
*
triton
.
cdiv
(
M
,
BLOCK_M
)
*
triton
.
cdiv
(
N
,
BLOCK_N
),
NUM_SLICES
,
num_active_loras
,
)
# Determine scale strides
if
use_fp8_w8a8
:
if
a_scale
is
not
None
and
a_scale
.
ndim
==
2
:
a_scale_m_stride
=
a_scale
.
stride
(
0
)
a_scale_k_stride
=
a_scale
.
stride
(
1
)
else
:
a_scale_m_stride
=
0
a_scale_k_stride
=
0
else
:
a_scale_m_stride
=
0
a_scale_k_stride
=
0
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
use_gdc
=
False
# supports_pdl(inputs.device)
_lora_shrink_kernel_fp8
[
grid
](
inputs
,
lora_ptr_tensor
,
output_tensor
,
a_scale_ptr
,
b_scale_ptr_tensor
,
M
,
N
,
K
,
token_indices_sorted_by_lora_ids
,
num_tokens_per_lora
,
lora_token_start_loc
,
lora_ids
,
scaling
,
inputs
.
stride
(
0
),
inputs
.
stride
(
1
),
lora_strides_d0
,
lora_strides_d1
,
lora_strides_d2
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
output_tensor
.
stride
(
2
),
group_n
,
group_k
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
GROUP_SIZE_M
,
NUM_SLICES
,
use_gdc
,
use_fp8_w8a8
,
per_channel_quant
,
use_gdc
,
num_warps
=
NUM_WARPS
,
num_ctas
=
NUM_CTAS
,
num_stages
=
NUM_STAGES
,
)
return
def
_lora_shrink_fp8_fake
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
list
[
torch
.
Tensor
],
output_tensor
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
num_tokens_per_lora
:
torch
.
Tensor
,
lora_token_start_loc
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
no_lora_flag_cpu
:
torch
.
Tensor
,
num_active_loras
:
int
,
scaling
:
float
,
b_scale
:
list
[
torch
.
Tensor
],
# LoRA weight scale per slice
a_scale
:
torch
.
Tensor
|
None
=
None
,
# Activation scale - per-token or block-wise
group_k
:
int
=
0
,
# Block size for K in block-wise quantization (0 = tensor-wise)
group_n
:
int
=
0
,
# Block size for N in block-wise quantization
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
return
try
:
direct_register_custom_op
(
op_name
=
"lora_shrink_fp8"
,
op_func
=
_lora_shrink_fp8
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
_lora_shrink_fp8_fake
,
)
lora_shrink_fp8
=
torch
.
ops
.
vllm
.
lora_shrink_fp8
except
AttributeError
:
lora_shrink_fp8
=
_lora_shrink_fp8
vllm/lora/ops/triton_ops/utils.py
View file @
b3ce711b
...
@@ -252,7 +252,7 @@ def get_lora_op_configs(
...
@@ -252,7 +252,7 @@ def get_lora_op_configs(
default
=
{
default
=
{
"block_m"
:
64
,
"block_m"
:
64
,
"block_n"
:
64
if
num_slices
>
1
else
128
,
"block_n"
:
64
if
num_slices
>
1
else
128
,
"block_k"
:
16
,
"block_k"
:
32
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_ctas"
:
1
,
"num_stages"
:
2
,
"num_stages"
:
2
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment