Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3ce711b
Unverified
Commit
b3ce711b
authored
Mar 13, 2026
by
yugong333
Committed by
GitHub
Mar 13, 2026
Browse files
Fp8 lora dense kernel (#35242)
Signed-off-by:
Yu Gong
<
yu3.gong@gmail.com
>
parent
abf61aaa
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
2439 additions
and
1 deletion
+2439
-1
tests/lora/test_punica_ops_fp8.py
tests/lora/test_punica_ops_fp8.py
+999
-0
vllm/lora/ops/triton_ops/__init__.py
vllm/lora/ops/triton_ops/__init__.py
+4
-0
vllm/lora/ops/triton_ops/fp8_kernel_utils.py
vllm/lora/ops/triton_ops/fp8_kernel_utils.py
+603
-0
vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
+403
-0
vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py
vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py
+429
-0
vllm/lora/ops/triton_ops/utils.py
vllm/lora/ops/triton_ops/utils.py
+1
-1
No files found.
tests/lora/test_punica_ops_fp8.py
0 → 100644
View file @
b3ce711b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FP8 accuracy tests for LoRA shrink and expand kernels.
Tests the FP8 kernels by:
1. Quantizing bf16 inputs/weights to FP8
2. Dequantizing them back to bf16
3. Running the bf16 reference (sgmv_shrink/sgmv_expand) with dequantized values
4. Comparing FP8 kernel output against this dequantized reference
This isolates kernel correctness from quantization precision loss,
allowing much tighter tolerances than comparing against the original bf16.
"""
import
math
from
threading
import
Lock
import
pytest
import
torch
import
vllm.lora.ops.torch_ops
as
torch_ops
import
vllm.lora.ops.triton_ops
as
triton_ops
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops.lora_expand_fp8_op
import
(
_EXPAND_LORA_SCALE_PTR_DICT
,
)
from
vllm.lora.ops.triton_ops.lora_shrink_fp8_op
import
(
_SHRINK_LORA_SCALE_PTR_DICT
,
)
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.utils.torch_utils
import
set_random_seed
# Devices the parametrized tests below run on.
DEVICES = ["cuda:0"]
# Random seeds the parametrized tests below run with.
SEED = [0]
# Held by the kernel checks while they clear the module-level *_PTR_DICT
# caches and launch the kernel.
_dict_lock = Lock()
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
    """Autouse fixture that requests ``reset_default_device`` for every test.

    The body is intentionally empty: depending on the other fixture is the
    whole purpose.
    NOTE(review): ``reset_default_device`` is not defined in this file --
    presumably it comes from a conftest; confirm.
    """
# ============================================================================
# Reference implementations (bf16 baseline)
# ============================================================================
def sgmv_shrink_for_nslices(
    nslices,
    inputs_tensor,
    lora_weights_lst,
    out_tensor,
    b_seq_start_loc,
    seq_len_tensor,
    prompt_lora_mapping,
    batches,
    max_seq_length,
    num_tokens,
    scaling,
):
    """Run ``torch_ops.sgmv_shrink`` once for each of ``nslices`` slices.

    Slice ``i`` pairs ``lora_weights_lst[i]`` with ``out_tensor[i]``; all
    remaining arguments are forwarded unchanged to every call.
    """
    # Positional arguments that are the same for every slice.
    shared_args = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
        scaling,
    )
    for slice_id in range(nslices):
        torch_ops.sgmv_shrink(
            inputs_tensor,
            lora_weights_lst[slice_id],
            out_tensor[slice_id],
            *shared_args,
        )
def sgmv_expand_for_nslices(
    nslices,
    hidden_size,
    inputs_tensor,
    lora_weights_lst,
    out_tensor,
    b_seq_start_loc,
    seq_len_tensor,
    prompt_lora_mapping,
    batches,
    max_seq_length,
    num_tokens,
    add_inputs,
):
    """Run the torch expand reference for one or several slices.

    With a single slice ``torch_ops.sgmv_expand`` is called on slice 0.
    Otherwise ``torch_ops.sgmv_expand_slice`` is called once per slice, and
    slice ``i`` is given the offset ``i * hidden_size`` and the size
    ``hidden_size``.
    """
    # Positional arguments that are the same for every call.
    shared_args = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
    )

    if nslices == 1:
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            *shared_args,
            add_inputs=add_inputs,
        )
        return

    for slice_id in range(nslices):
        torch_ops.sgmv_expand_slice(
            inputs_tensor[slice_id],
            lora_weights_lst[slice_id],
            out_tensor,
            *shared_args,
            slice_id * hidden_size,  # slice offset
            hidden_size,  # slice size
            add_inputs=add_inputs,
        )
# ============================================================================
# FP8 Quantization Helpers
# ============================================================================
FP8_DTYPE
=
torch
.
float8_e4m3fn
FP8_MAX
=
torch
.
finfo
(
FP8_DTYPE
).
max
FP8_MIN
=
torch
.
finfo
(
FP8_DTYPE
).
min
def quantize_to_fp8_per_tensor(
    tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` to FP8 using a single scale for the whole tensor.

    Returns ``(fp8_tensor, scale)``. ``scale`` is a float32 tensor of shape
    ``(1,)`` holding ``max(|tensor|) / FP8_MAX``.
    """
    values = tensor.float()
    # Floor the absolute maximum so an all-zero tensor cannot yield scale 0.
    abs_max = values.abs().max().clamp(min=1e-12)
    scale = (abs_max / FP8_MAX).to(torch.float32)
    quantized = (values / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    return quantized, scale.reshape(1)
def quantize_to_fp8_per_channel(
    tensor: torch.Tensor,
    channel_dim: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` to FP8 with one scale per channel.

    The absolute maximum is reduced over every dim after ``channel_dim``,
    so one scale is kept per index of the dims up to and including
    ``channel_dim``.

    For shrink lora_a weights of shape (num_loras, rank, hidden_size):
        channel_dim=1 gives one scale per (lora, rank) pair.
    For expand lora_b weights of shape (num_loras, hidden_size, rank):
        channel_dim=1 gives one scale per (lora, hidden) pair.

    Returns ``(fp8_tensor, scale)`` with ``scale`` in float32.
    NOTE(review): the returned scale goes through a blanket ``squeeze()``,
    so *every* size-1 dim is dropped (e.g. num_loras == 1 yields a 1-D
    scale, not (1, rank)). Confirm the consumers are fine with that.
    """
    values = tensor.float()
    abs_values = values.abs()

    trailing_dims = list(range(channel_dim + 1, tensor.ndim))
    if trailing_dims:
        # keepdim leaves size-1 trailing dims so the scale broadcasts
        # directly against the input.
        abs_values = abs_values.amax(dim=trailing_dims, keepdim=True)

    # Floor the maximum so an all-zero channel cannot yield scale 0.
    broadcast_scale = (abs_values.clamp(min=1e-12) / FP8_MAX).to(torch.float32)
    quantized = (values / broadcast_scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)

    compact_scale = broadcast_scale.squeeze()
    if compact_scale.ndim == 0:
        # Never hand back a 0-dim scale.
        compact_scale = compact_scale.unsqueeze(0)
    return quantized, compact_scale
def quantize_to_fp8_per_token(
    tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to FP8 with one scale per row (token).

    Input shape: (num_tokens, hidden_size)
    Returns: (fp8_tensor, scale) where scale is float32 of shape
    (num_tokens, 1)
    """
    assert tensor.ndim == 2
    values = tensor.float()
    # Row-wise absolute maximum, floored so an all-zero row cannot yield
    # scale 0.
    row_abs_max = values.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = (row_abs_max / FP8_MAX).to(torch.float32)
    quantized = (values / scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    return quantized, scale
def quantize_to_fp8_blockwise(
    tensor: torch.Tensor,
    group_n: int,
    group_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D or 3D tensor to FP8 with block-wise scaling.

    For a 2D tensor (num_tokens, hidden_size):
        Blocks of size (1, group_k) ->
        scale shape (num_tokens, ceil(hidden_size/group_k))
        (``group_n`` is ignored in this case.)
    For a 3D tensor (num_loras, N, K):
        Blocks of size (group_n, group_k) ->
        scale shape (num_loras, ceil(N/group_n), ceil(K/group_k))

    Each block's scale is ``max(|block|) / FP8_MAX`` with the maximum
    floored at 1e-12; trailing blocks may be partial.

    Returns:
        ``(fp8_tensor, scale)``: the FP8 tensor has the input's shape and
        the scale is float32.

    Raises:
        ValueError: if ``tensor`` is neither 2D nor 3D.
    """
    if tensor.ndim not in (2, 3):
        raise ValueError(f"Unsupported tensor ndim: {tensor.ndim}")

    is_2d = tensor.ndim == 2
    values = tensor.float()
    if is_2d:
        # Treat (M, K) as a single batch whose rows are their own blocks.
        values = values.unsqueeze(0)
        group_n = 1

    num_batches, N, K = values.shape
    n_blocks_n = math.ceil(N / group_n)
    n_blocks_k = math.ceil(K / group_k)

    # Zero-pad |x| up to whole blocks; zeros can never raise a block maximum,
    # so partial trailing blocks get the same scale as without padding.
    padded_abs = torch.nn.functional.pad(
        values.abs(),
        (0, n_blocks_k * group_k - K, 0, n_blocks_n * group_n - N),
    )
    block_amax = (
        padded_abs.reshape(num_batches, n_blocks_n, group_n, n_blocks_k, group_k)
        .amax(dim=(2, 4))
        .clamp(min=1e-12)
    )
    scale = (block_amax / FP8_MAX).to(torch.float32)

    # Expand each block's scale back to one value per element.
    elem_scale = scale.repeat_interleave(group_n, dim=1)[:, :N]
    elem_scale = elem_scale.repeat_interleave(group_k, dim=2)[:, :, :K]
    fp8_tensor = (values / elem_scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)

    if is_2d:
        return fp8_tensor.squeeze(0), scale.squeeze(0)
    return fp8_tensor, scale
# ============================================================================
# FP8 Dequantization Helpers
# ============================================================================
def dequantize_fp8_per_tensor(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Undo per-tensor FP8 quantization.

    Multiplies the float32 view of ``fp8_tensor`` by the float32 ``scale``
    and casts the product to ``output_dtype``.
    """
    restored = fp8_tensor.float()
    restored = restored * scale.float()
    return restored.to(output_dtype)
def dequantize_fp8_per_channel(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    channel_dim: int,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Undo per-channel FP8 quantization.

    ``scale`` gets one trailing size-1 dim for every dim of ``fp8_tensor``
    after ``channel_dim`` so that it broadcasts over those dims.

    For 3D tensor (num_loras, N, K) with channel_dim=1:
        scale shape is (num_loras, N), broadcast over K.
    """
    n_trailing = len(range(channel_dim + 1, fp8_tensor.ndim))
    # Indexing with None appends the size-1 broadcast dims in one step.
    broadcast_scale = scale.float()[(Ellipsis,) + (None,) * n_trailing]
    product = fp8_tensor.float() * broadcast_scale
    return product.to(output_dtype)
def dequantize_fp8_per_token(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Undo per-token FP8 quantization of a 2D tensor.

    fp8_tensor: (num_tokens, hidden_size), scale: (num_tokens, 1); the
    scale broadcasts across each row before the cast to ``output_dtype``.
    """
    scaled = torch.mul(fp8_tensor.float(), scale.float())
    return scaled.to(output_dtype)
def dequantize_fp8_blockwise(
    fp8_tensor: torch.Tensor,
    scale: torch.Tensor,
    group_n: int,
    group_k: int,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize FP8 tensor with block-wise scale back to output_dtype.

    For a 2D tensor (M, K), ``scale[m, bk]`` applies to columns
    ``bk * group_k`` onwards of row ``m`` (``group_n`` is ignored).
    For a 3D tensor (L, N, K), ``scale[l, bn, bk]`` applies to the
    (group_n, group_k) block starting at row ``bn * group_n`` and column
    ``bk * group_k``. Trailing blocks may be partial.

    Raises:
        ValueError: if ``fp8_tensor`` is neither 2D nor 3D, or if ``scale``
            does not provide a value for every block.
    """
    if fp8_tensor.ndim not in (2, 3):
        raise ValueError(f"Unsupported tensor ndim: {fp8_tensor.ndim}")

    # Expand the block scales to one value per element, then multiply once
    # instead of launching a separate op for every block.
    k_dim = fp8_tensor.ndim - 1
    elem_scale = scale.float()
    if fp8_tensor.ndim == 3:
        elem_scale = elem_scale.repeat_interleave(group_n, dim=1)
        elem_scale = elem_scale[:, : fp8_tensor.shape[1]]
    elem_scale = elem_scale.repeat_interleave(group_k, dim=k_dim)
    elem_scale = elem_scale[..., : fp8_tensor.shape[k_dim]]

    if elem_scale.shape != fp8_tensor.shape:
        # Refuse to broadcast a scale that does not cover the whole tensor.
        raise ValueError(
            f"scale shape {tuple(scale.shape)} does not cover tensor shape "
            f"{tuple(fp8_tensor.shape)} with group_n={group_n}, group_k={group_k}"
        )
    return (fp8_tensor.float() * elem_scale).to(output_dtype)
# ============================================================================
# FP8 Data Generation
# ============================================================================
def generate_fp8_shrink_data(
    batches: int,
    hidden_size: int,
    num_loras: int,
    rank: int,
    seq_length: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    quant_mode: str,  # "per_tensor", "per_channel", "blockwise"
    group_k: int = 128,
    group_n: int = 128,
):
    """Generate test data for FP8 shrink kernel.

    Shrink: output = input @ lora_a^T * scaling
      input: (num_tokens, hidden_size) -> quantized to FP8
      lora_a: (num_loras, rank, hidden_size) -> quantized to FP8

    Input quantization per mode: "blockwise" uses (1, group_k) blocks,
    "per_tensor" a single scale, "per_channel" one scale per token.
    Weight quantization per mode: "per_tensor" (the single scale is
    replicated to shape (num_loras,)), "per_channel" along the rank dim,
    "blockwise" with (group_n, group_k) blocks.

    Returns a dict with the original tensors (``inputs_bf16``,
    ``lora_a_bf16``), their FP8 versions and scales (``inputs_fp8``,
    ``a_scale``, ``lora_a_fp8``, ``b_scales``), the round-tripped tensors
    used by the reference (``inputs_dequant``, ``lora_a_dequant``), zeroed
    float32 outputs (``out_tensor``, ``ref_out_tensor``) and the sequence /
    LoRA mapping tensors plus ``total_tokens``.

    Raises:
        ValueError: if ``quant_mode`` is not one of the three supported
            modes (previously this silently produced empty weight lists).
    """
    if quant_mode not in ("per_tensor", "per_channel", "blockwise"):
        raise ValueError(f"Unsupported quant_mode: {quant_mode!r}")

    # randint(low, low + 1) always yields ``seq_length``: every sequence has
    # the same length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum().item()

    # Generate reference data in ``dtype``.
    inputs_bf16 = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device)
    lora_a_weights_bf16 = [
        torch.randn(num_loras, rank, hidden_size, dtype=dtype, device=device)
        for _ in range(nslices)
    ]

    # Quantize inputs to FP8 and dequantize back for the reference.
    if quant_mode == "blockwise":
        inputs_fp8, a_scale = quantize_to_fp8_blockwise(
            inputs_bf16, group_n=1, group_k=group_k
        )
        inputs_dequant = dequantize_fp8_blockwise(
            inputs_fp8, a_scale, group_n=1, group_k=group_k, output_dtype=dtype
        )
    elif quant_mode == "per_tensor":
        inputs_fp8, a_scale = quantize_to_fp8_per_tensor(inputs_bf16)
        inputs_dequant = dequantize_fp8_per_tensor(
            inputs_fp8, a_scale, output_dtype=dtype
        )
    else:
        # per_channel: activations get one scale per token.
        inputs_fp8, a_scale = quantize_to_fp8_per_token(inputs_bf16)
        inputs_dequant = dequantize_fp8_per_token(
            inputs_fp8, a_scale, output_dtype=dtype
        )

    # Quantize lora_a weights to FP8 and dequantize back for the reference.
    b_scales = []
    lora_a_weights_fp8 = []
    lora_a_weights_dequant = []
    for w in lora_a_weights_bf16:
        if quant_mode == "per_tensor":
            w_fp8, w_scale = quantize_to_fp8_per_tensor(w)
            w_dequant = dequantize_fp8_per_tensor(w_fp8, w_scale, output_dtype=dtype)
            # Scale shape: (1,) -> replicate to (num_loras,) for the kernel.
            w_scale = w_scale.expand(num_loras).contiguous()
        elif quant_mode == "per_channel":
            # Per-channel along the rank dim.
            w_fp8, w_scale = quantize_to_fp8_per_channel(w, channel_dim=1)
            w_dequant = dequantize_fp8_per_channel(
                w_fp8, w_scale, channel_dim=1, output_dtype=dtype
            )
        else:
            # blockwise (quant_mode was validated above).
            w_fp8, w_scale = quantize_to_fp8_blockwise(
                w, group_n=group_n, group_k=group_k
            )
            w_dequant = dequantize_fp8_blockwise(
                w_fp8, w_scale, group_n=group_n, group_k=group_k, output_dtype=dtype
            )
        lora_a_weights_fp8.append(w_fp8)
        b_scales.append(w_scale)
        lora_a_weights_dequant.append(w_dequant)

    # Output tensor (float32 for shrink).
    out_tensor = torch.zeros(
        nslices, total_tokens, rank, dtype=torch.float32, device=device
    )
    ref_out_tensor = out_tensor.clone()

    # Token-to-lora mapping: every token of a sequence uses that sequence's
    # LoRA index.
    # NOTE(review): randint's upper bound is exclusive, so for num_loras > 1
    # the last LoRA index is never drawn. Kept as-is to leave the generated
    # data unchanged; confirm whether that is intended.
    lora_indices_tensor = torch.randint(0, max(num_loras - 1, 1), (batches,)).to(device)
    token_lora_mapping = torch.zeros(total_tokens, dtype=torch.long, device=device)
    current_offset = 0
    for b_id in range(batches):
        sl = seq_len_tensor[b_id].item()
        token_lora_mapping[current_offset : current_offset + sl] = lora_indices_tensor[
            b_id
        ]
        current_offset += sl

    return {
        "inputs_bf16": inputs_bf16,
        "inputs_fp8": inputs_fp8,
        "inputs_dequant": inputs_dequant,
        "lora_a_bf16": lora_a_weights_bf16,
        "lora_a_fp8": lora_a_weights_fp8,
        "lora_a_dequant": lora_a_weights_dequant,
        "a_scale": a_scale,
        "b_scales": b_scales,
        "out_tensor": out_tensor,
        "ref_out_tensor": ref_out_tensor,
        "token_lora_mapping": token_lora_mapping,
        "seq_len_tensor": seq_len_tensor,
        "b_seq_start_loc": b_seq_start_loc,
        "lora_indices_tensor": lora_indices_tensor,
        "total_tokens": total_tokens,
    }
def generate_fp8_expand_data(
    batches: int,
    hidden_size: int,
    num_loras: int,
    rank: int,
    seq_length: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    quant_mode: str,  # "per_tensor", "per_channel", "blockwise"
    group_k: int = 128,
    group_n: int = 128,
):
    """Generate test data for FP8 expand kernel (w8a8).

    Expand: output += input @ lora_b^T
      input: (nslices, num_tokens, rank) -> quantized to FP8 (activations)
      lora_b: (num_loras, hidden_size, rank) -> quantized to FP8 (weights)

    In w8a8 mode, both activations and weights are FP8. For "per_channel"
    and "blockwise" the activation scale is shared by all slices: it is
    derived from the maximum over slices, giving ``a_scale`` of shape
    (num_tokens, 1) or (num_tokens, ceil(rank / group_k)) respectively.

    Returns a dict with the original tensors (``inputs_bf16``,
    ``lora_b_bf16``), their FP8 versions and scales (``inputs_fp8``,
    ``a_scale``, ``lora_b_fp8``, ``b_scales``), the round-tripped tensors
    used by the reference (``inputs_dequant``, ``lora_b_dequant``), randomly
    initialised outputs (``out_tensor``, ``ref_out_tensor``) and the
    sequence / LoRA mapping tensors plus ``total_tokens``.

    Raises:
        ValueError: if ``quant_mode`` is not one of the three supported
            modes (previously this silently produced empty weight lists).
    """
    if quant_mode not in ("per_tensor", "per_channel", "blockwise"):
        raise ValueError(f"Unsupported quant_mode: {quant_mode!r}")

    # randint(low, low + 1) always yields ``seq_length``: every sequence has
    # the same length.
    seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum().item()

    # Generate the input (stand-in for the shrink output) in ``dtype``.
    inputs_bf16 = torch.randn(nslices, total_tokens, rank, dtype=dtype, device=device)

    # Quantize input to FP8 and dequantize back for the reference.
    if quant_mode == "per_tensor":
        # One scalar scale for all slices; the helpers are shape-agnostic.
        inputs_fp8, a_scale = quantize_to_fp8_per_tensor(inputs_bf16)
        inputs_dequant = dequantize_fp8_per_tensor(
            inputs_fp8, a_scale, output_dtype=dtype
        )
    else:
        inputs_f32 = inputs_bf16.float()
        if quant_mode == "blockwise":
            # The kernel is expected to index a_scale by token id only
            # (shared by all slices), so each (token, k-block) scale comes
            # from the maximum over every slice.
            n_blocks_k = math.ceil(rank / group_k)
            token_abs_max = inputs_f32.abs().amax(dim=0)  # (tokens, rank)
            # Zero padding to whole blocks cannot raise a block maximum.
            padded = torch.nn.functional.pad(
                token_abs_max, (0, n_blocks_k * group_k - rank)
            )
            block_amax = (
                padded.reshape(total_tokens, n_blocks_k, group_k)
                .amax(dim=2)
                .clamp(min=1e-12)
            )
            a_scale = (block_amax / FP8_MAX).to(torch.float32)
            # One scale value per (token, rank) element.
            elem_scale = a_scale.repeat_interleave(group_k, dim=1)[:, :rank]
        else:
            # per_channel: one scale per token, shared by all slices.
            amax = inputs_f32.abs().amax(dim=(0, 2)).clamp(min=1e-12)
            a_scale = (amax / FP8_MAX).to(torch.float32).unsqueeze(1)  # (tokens, 1)
            elem_scale = a_scale
        # The scale broadcasts over the leading slice dim.
        inputs_fp8 = (inputs_f32 / elem_scale).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
        inputs_dequant = (inputs_fp8.float() * elem_scale).to(dtype)

    # Generate LoRA B weights in ``dtype``.
    lora_b_weights_bf16 = [
        torch.randn(num_loras, hidden_size, rank, dtype=dtype, device=device)
        for _ in range(nslices)
    ]

    # Quantize LoRA B weights to FP8 and dequantize back for the reference.
    b_scales = []
    lora_b_weights_fp8 = []
    lora_b_weights_dequant = []
    for w in lora_b_weights_bf16:
        if quant_mode == "per_tensor":
            w_fp8, w_scale = quantize_to_fp8_per_tensor(w)
            w_dequant = dequantize_fp8_per_tensor(w_fp8, w_scale, output_dtype=dtype)
            # Scale shape: (1,) -> replicate to (num_loras,) for the kernel.
            w_scale = w_scale.expand(num_loras).contiguous()
        elif quant_mode == "per_channel":
            # Per-channel along the hidden_size dim.
            w_fp8, w_scale = quantize_to_fp8_per_channel(w, channel_dim=1)
            w_dequant = dequantize_fp8_per_channel(
                w_fp8, w_scale, channel_dim=1, output_dtype=dtype
            )
        else:
            # blockwise (quant_mode was validated above).
            w_fp8, w_scale = quantize_to_fp8_blockwise(
                w, group_n=group_n, group_k=group_k
            )
            w_dequant = dequantize_fp8_blockwise(
                w_fp8, w_scale, group_n=group_n, group_k=group_k, output_dtype=dtype
            )
        lora_b_weights_fp8.append(w_fp8)
        b_scales.append(w_scale)
        lora_b_weights_dequant.append(w_dequant)

    # Output tensor (initialized randomly for add_inputs).
    out_tensor = torch.randn(
        total_tokens, hidden_size * nslices, dtype=dtype, device=device
    )
    ref_out_tensor = out_tensor.clone()

    # Token-to-lora mapping: every token of a sequence uses that sequence's
    # LoRA index.
    # NOTE(review): randint's upper bound is exclusive, so for num_loras > 1
    # the last LoRA index is never drawn. Kept as-is to leave the generated
    # data unchanged; confirm whether that is intended.
    lora_indices_tensor = torch.randint(0, max(num_loras - 1, 1), (batches,)).to(device)
    token_lora_mapping = torch.zeros(total_tokens, dtype=torch.long, device=device)
    current_offset = 0
    for b_id in range(batches):
        sl = seq_len_tensor[b_id].item()
        token_lora_mapping[current_offset : current_offset + sl] = lora_indices_tensor[
            b_id
        ]
        current_offset += sl

    return {
        "inputs_bf16": inputs_bf16,
        "inputs_fp8": inputs_fp8,
        "inputs_dequant": inputs_dequant,
        "a_scale": a_scale,
        "lora_b_bf16": lora_b_weights_bf16,
        "lora_b_fp8": lora_b_weights_fp8,
        "lora_b_dequant": lora_b_weights_dequant,
        "b_scales": b_scales,
        "out_tensor": out_tensor,
        "ref_out_tensor": ref_out_tensor,
        "token_lora_mapping": token_lora_mapping,
        "seq_len_tensor": seq_len_tensor,
        "b_seq_start_loc": b_seq_start_loc,
        "lora_indices_tensor": lora_indices_tensor,
        "total_tokens": total_tokens,
    }
# ============================================================================
# FP8 Shrink Kernel Check
# ============================================================================
def check_lora_shrink_fp8_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    scaling: float,
    quant_mode: str,
    group_k: int = 128,
    group_n: int = 128,
):
    """Check ``triton_ops.lora_shrink_fp8`` against a dequantized reference.

    The FP8 kernel consumes the quantized inputs/weights and their scales.
    The reference (``sgmv_shrink_for_nslices``) consumes the same values
    after a quantize -> dequantize round trip, so both sides see the same
    quantization error and any remaining difference comes from the kernel.

    Both outputs are cast to ``dtype`` and compared with rtol=0.1,
    atol=0.25. ``group_k`` / ``group_n`` are only forwarded to the kernel
    for "blockwise"; the other modes pass 0.
    """
    data = generate_fp8_shrink_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        device,
        quant_mode,
        group_k,
        group_n,
    )
    total_tokens = data["total_tokens"]

    # LoRA kernel metadata built from the token -> LoRA mapping.
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras, max_num_tokens=total_tokens, device=device
    )
    lora_meta.prepare_tensors(data["token_lora_mapping"])

    out_tensor = data["out_tensor"]

    # Quantization settings handed to the kernel.
    is_blockwise = quant_mode == "blockwise"
    quant_kwargs = dict(
        a_scale=data["a_scale"],
        group_k=group_k if is_blockwise else 0,
        group_n=group_n if is_blockwise else 0,
        use_fp8_w8a8=True,
        per_channel_quant=quant_mode == "per_channel",
    )

    with _dict_lock:
        # Drop cached entries from earlier parametrizations before launching.
        _LORA_A_PTR_DICT.clear()
        _SHRINK_LORA_SCALE_PTR_DICT.clear()
        triton_ops.lora_shrink_fp8(
            data["inputs_fp8"],
            data["lora_a_fp8"],
            out_tensor,
            *lora_meta.meta_args(token_nums=total_tokens, specialize_active_lora=False),
            scaling,
            data["b_scales"],
            **quant_kwargs,
        )

    # Reference on the round-tripped (dequantized) tensors.
    ref_out_tensor = data["ref_out_tensor"]
    seq_lens = data["seq_len_tensor"]
    sgmv_shrink_for_nslices(
        nslices,
        data["inputs_dequant"],
        data["lora_a_dequant"],
        ref_out_tensor,
        data["b_seq_start_loc"],
        seq_lens,
        data["lora_indices_tensor"],
        batches,
        seq_lens.max().item(),
        total_tokens,
        scaling,
    )

    # Blockwise accumulation order differs from the reference, so the
    # margin leaves room for sporadic rounding outliers.
    torch.testing.assert_close(
        out_tensor.to(dtype), ref_out_tensor.to(dtype), rtol=0.1, atol=0.25
    )
# ============================================================================
# FP8 Expand Kernel Check
# ============================================================================
def check_lora_expand_fp8_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    add_inputs: bool,
    quant_mode: str,
    group_k: int = 128,
    group_n: int = 128,
):
    """Check ``triton_ops.lora_expand_fp8`` (w8a8) against a dequantized reference.

    The FP8 kernel consumes the quantized inputs/weights and their scales.
    The reference (``sgmv_expand_for_nslices``) consumes the same values
    after a quantize -> dequantize round trip, so both sides see the same
    quantization error and any remaining difference comes from the kernel.

    Outputs are compared with rtol=0.1, atol=0.15. ``group_k`` /
    ``group_n`` are only forwarded to the kernel for "blockwise"; the other
    modes pass 0.
    """
    data = generate_fp8_expand_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        device,
        quant_mode,
        group_k,
        group_n,
    )
    total_tokens = data["total_tokens"]

    # LoRA kernel metadata built from the token -> LoRA mapping.
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras, max_num_tokens=total_tokens, device=device
    )
    lora_meta.prepare_tensors(data["token_lora_mapping"])

    out_tensor = data["out_tensor"]

    # Quantization settings handed to the kernel.
    is_blockwise = quant_mode == "blockwise"
    kernel_kwargs = dict(
        a_scale=data["a_scale"],
        offset_start=0,
        add_inputs=add_inputs,
        group_k=group_k if is_blockwise else 0,
        group_n=group_n if is_blockwise else 0,
        use_fp8_w8a8=True,
        per_channel_quant=quant_mode == "per_channel",
    )

    with _dict_lock:
        # Drop cached entries from earlier parametrizations before launching.
        _LORA_B_PTR_DICT.clear()
        _EXPAND_LORA_SCALE_PTR_DICT.clear()
        triton_ops.lora_expand_fp8(
            data["inputs_fp8"],
            data["lora_b_fp8"],
            out_tensor,
            *lora_meta.meta_args(token_nums=total_tokens, specialize_active_lora=False),
            data["b_scales"],
            **kernel_kwargs,
        )

    # Reference on the round-tripped (dequantized) tensors.
    ref_out_tensor = data["ref_out_tensor"]
    seq_lens = data["seq_len_tensor"]
    sgmv_expand_for_nslices(
        nslices,
        hidden_size,
        data["inputs_dequant"],
        data["lora_b_dequant"],
        ref_out_tensor,
        data["b_seq_start_loc"],
        seq_lens,
        data["lora_indices_tensor"],
        batches,
        seq_lens.max().item(),
        total_tokens,
        add_inputs=add_inputs,
    )

    torch.testing.assert_close(out_tensor, ref_out_tensor, rtol=0.1, atol=0.15)
# ============================================================================
# FP8 Test Parameters
# ============================================================================
# Value grids consumed by the parametrized FP8 shrink/expand tests.
fp8_test_params = dict(
    hidden_sizes=[512, 1024, 2048],
    batches=[1, 4, 16],
    num_loras=[1, 4, 8],
    max_ranks=[8, 16, 32, 64],
)
# ============================================================================
# FP8 Shrink Tests
# ============================================================================
@pytest.mark.parametrize("batches", fp8_test_params["batches"])
@pytest.mark.parametrize("num_loras", fp8_test_params["num_loras"])
@pytest.mark.parametrize("rank", fp8_test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", fp8_test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("quant_mode", ["per_tensor", "per_channel", "blockwise"])
def test_lora_shrink_fp8(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    quant_mode: str,
):
    """Run the FP8 shrink check for per-tensor, per-channel and block-wise
    quantization with sequence length 128 and scaling 0.5."""
    torch.set_default_device(device)
    set_random_seed(seed)

    # Blockwise groups are capped at the dimension they tile: K is
    # hidden_size and N is rank for shrink. Other modes keep the 128 default.
    is_blockwise = quant_mode == "blockwise"
    group_k = min(128, hidden_size) if is_blockwise else 128
    group_n = min(128, rank) if is_blockwise else 128

    check_lora_shrink_fp8_kernel(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
        scaling=0.5,
        quant_mode=quant_mode,
        group_k=group_k,
        group_n=group_n,
    )
# ============================================================================
# FP8 Expand Tests
# ============================================================================
@pytest.mark.parametrize("batches", fp8_test_params["batches"])
@pytest.mark.parametrize("num_loras", fp8_test_params["num_loras"])
@pytest.mark.parametrize("rank", fp8_test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", fp8_test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("quant_mode", ["per_tensor", "per_channel", "blockwise"])
def test_lora_expand_fp8(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    quant_mode: str,
):
    """Run the FP8 expand check for per-tensor, per-channel and block-wise
    quantization with sequence length 128 and ``add_inputs=True``."""
    torch.set_default_device(device)
    set_random_seed(seed)

    # Blockwise groups are capped at the dimension they tile: K is rank and
    # N is hidden_size for expand. Other modes keep the 128 default.
    is_blockwise = quant_mode == "blockwise"
    group_k = min(128, rank) if is_blockwise else 128
    group_n = min(128, hidden_size) if is_blockwise else 128

    check_lora_expand_fp8_kernel(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
        add_inputs=True,
        quant_mode=quant_mode,
        group_k=group_k,
        group_n=group_n,
    )
vllm/lora/ops/triton_ops/__init__.py
View file @
b3ce711b
...
...
@@ -12,13 +12,17 @@ from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
fused_moe_lora_expand
,
fused_moe_lora_shrink
,
)
from
vllm.lora.ops.triton_ops.lora_expand_fp8_op
import
lora_expand_fp8
from
vllm.lora.ops.triton_ops.lora_expand_op
import
lora_expand
from
vllm.lora.ops.triton_ops.lora_kernel_metadata
import
LoRAKernelMeta
from
vllm.lora.ops.triton_ops.lora_shrink_fp8_op
import
lora_shrink_fp8
from
vllm.lora.ops.triton_ops.lora_shrink_op
import
lora_shrink
__all__
=
[
"lora_expand"
,
"lora_expand_fp8"
,
"lora_shrink"
,
"lora_shrink_fp8"
,
"LoRAKernelMeta"
,
"fused_moe_lora"
,
"fused_moe_lora_shrink"
,
...
...
vllm/lora/ops/triton_ops/fp8_kernel_utils.py
0 → 100644
View file @
b3ce711b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utilities for Punica kernel construction.
"""
from
vllm.triton_utils
import
tl
,
triton
@triton.jit
def _accumulate_mm(
    tiled_a,
    tiled_b,
    accumulator,
    a_scale_ptr,
    b_scale_ptr,
    a_scale_k_stride,
    b_scale_k_stride,
    iter_k,
    group_k: tl.constexpr,
    group_n: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
):
    """
    Multiply one pair of tiles and fold the product into the accumulator.

    Three compile-time paths are selected by the constexpr arguments:

    * ``use_fp8_w8a8`` with ``group_k > 0 and group_n > 0`` (block-wise):
      one A scale per row and one B scale per column are loaded for the
      current K-group and applied to this tile's product before it is added.
    * ``use_fp8_w8a8`` otherwise (tensor-wise / per-channel): the raw product
      is accumulated; the caller applies the scales after the K loop.
    * not ``use_fp8_w8a8``: plain accumulate.

    Args:
        tiled_a (tl.tensor): Loaded tile from A matrix
        tiled_b (tl.tensor): Loaded tile from B matrix
        accumulator (tl.tensor): Current accumulator value
        a_scale_ptr (tl.tensor): Scale pointer for A matrix; only read on the
            block-wise path
        b_scale_ptr (tl.tensor): Scale pointer for B matrix; only read on the
            block-wise path
        a_scale_k_stride (int): K dimension stride for A's block-wise scales
        b_scale_k_stride (int): K dimension stride for B's block-wise scales
        iter_k (int): Current iteration's global K offset
        group_k: Block size for K dimension in block-wise quantization
        group_n: Block size for N dimension in block-wise quantization
        use_fp8_w8a8: Whether using FP8 W8A8 quantization

    Returns:
        The updated accumulator.
    """
    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise quantization: scales are loaded per block
            # NOTE(review): a single K-group index is derived from the tile's
            # first K offset, so this presumably relies on a BLOCK_K tile not
            # straddling two group_k groups - confirm against the launch
            # configuration.
            offs_ks = iter_k // group_k
            # a_scale_ptr is (BLOCK_M,) tensor of base pointers per row
            # Load scale for current K-group, result shape: (BLOCK_M,)
            a_scale = tl.load(a_scale_ptr + offs_ks * a_scale_k_stride)
            # b_scale_ptr is (BLOCK_N,) tensor with N-offset pre-baked
            # Load scale for current K-group, result shape: (BLOCK_N,)
            b_scale = tl.load(b_scale_ptr + offs_ks * b_scale_k_stride)
            # Scale this tile's product row-wise by A and column-wise by B
            # before adding it, because the scales change per K-group.
            accumulator += (
                tl.dot(tiled_a, tiled_b) * a_scale[:, None] * b_scale[None, :]
            )
        else:
            # Tensor-wise or per-channel: accumulate and scale at end
            accumulator = tl.dot(tiled_a, tiled_b, acc=accumulator)
    else:
        # Non-quantized path: no scales involved.
        accumulator += tl.dot(tiled_a, tiled_b)
    return accumulator
@triton.jit
def fp8_mm_k(
    a_ptr,
    b_ptr,
    a_scale_ptr,
    b_scale_ptr,
    ak_stride,
    bk_stride,
    a_scale_k_stride,
    b_scale_k_stride,
    offset_k,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    group_k: tl.constexpr,
    group_n: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    b_dtype: tl.constexpr,
    USE_GDC: tl.constexpr,
    base_k,
):
    """
    FP8-compatible matrix multiplication kernel with quantization support.
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate through the K dimension to compute the partial/complete
    matrix block product with proper dequantization.

    Block-wise scales are applied inside the loop (see ``_accumulate_mm``);
    tensor-wise and per-channel scales are NOT applied here and must be
    applied by the caller to the returned accumulator.

    Args:
        a_ptr (tl.tensor): Array of pointers, identifying rows of A
            (FP8 or other dtype)
        b_ptr (tl.tensor): Array of pointers, identifying columns of B
            (FP8 dtype)
        a_scale_ptr (tl.tensor): Scale pointer for A matrix
            (per-token or block-wise)
        b_scale_ptr (tl.tensor): Scale pointer for B matrix
            (per-channel or block-wise)
        ak_stride (int): K dimension stride of the A matrix
        bk_stride (int): K dimension stride of the B matrix
        a_scale_k_stride (int): K dimension stride for A's block-wise scales
        b_scale_k_stride (int): K dimension stride for B's block-wise scales
        offset_k (int): Base offset along K dimension (not read in this body;
            the K position is tracked through ``base_k``)
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without masking
        SPLIT_K: Parameter signifying parallelism in the K dimension
        group_k: Block size for K dimension in block-wise quantization
        group_n: Block size for N dimension in block-wise quantization
        use_fp8_w8a8: Whether using FP8 W8A8 quantization
        per_channel_quant: Whether using per-channel quantization (not read
            in this body)
        CAST_TYPE: if True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: datatype of the B matrix
        USE_GDC: Whether to use PDL. True indicates use.
        base_k (int): Base offset along K dimension for current SPLIT_K group

    Returns:
        A float32 (BLOCK_M, BLOCK_N) accumulator tile.
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Step size along K for each iteration
    STEP_K = BLOCK_K * SPLIT_K
    # Total number of iterations (compile-time constant)
    num_iters = tl.cdiv(K, STEP_K)
    for k in range(num_iters):
        # Current iteration's global K offset
        iter_k = k * STEP_K + base_k
        block_end = iter_k + BLOCK_K
        # Skip iterations that are entirely past the K boundary
        if not EVEN_K and iter_k >= K:
            pass
        elif EVEN_K or block_end <= K:
            # No masking needed: either K is evenly divisible (EVEN_K)
            # or this block fits entirely within K
            tiled_b = tl.load(b_ptr)
            # B is loaded before the GDC wait and A after it.
            # NOTE(review): presumably because A is produced by the preceding
            # kernel while B (the LoRA weights) is not - confirm.
            if USE_GDC:
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(a_ptr)
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            accumulator = _accumulate_mm(
                tiled_a,
                tiled_b,
                accumulator,
                a_scale_ptr,
                b_scale_ptr,
                a_scale_k_stride,
                b_scale_k_stride,
                iter_k,
                group_k,
                group_n,
                use_fp8_w8a8,
            )
        else:
            # Partial block at the tail: mask out-of-bounds elements
            k_offsets = tl.arange(0, BLOCK_K)
            mask = iter_k + k_offsets < K
            # Masked-out elements are read as 0.0 so they add nothing to the
            # dot product.
            tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
            if USE_GDC:
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            accumulator = _accumulate_mm(
                tiled_a,
                tiled_b,
                accumulator,
                a_scale_ptr,
                b_scale_ptr,
                a_scale_k_stride,
                b_scale_k_stride,
                iter_k,
                group_k,
                group_n,
                use_fp8_w8a8,
            )
        # Advance both tile pointers by one full step, including on skipped
        # iterations, so pointer position stays in sync with iter_k.
        a_ptr += STEP_K * ak_stride
        b_ptr += STEP_K * bk_stride
    return accumulator
@triton.jit
def do_shrink_kernel_fp8(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    a_scale_ptr,
    b_scale_ptr,
    N,
    K,
    M_LEN,
    ram,
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # scale strides
    a_scale_m_stride,
    a_scale_k_stride,
    b_scale_l_stride,
    b_scale_n_stride,
    b_scale_k_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    # block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    USE_GDC: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice, compute the
    matrix product and store in the appropriate output location.

    Quantization handling (all selected at compile time):
      * block-wise (``group_k > 0 and group_n > 0``): per-row / per-column
        scale pointers are built here and the scales are applied inside
        ``fp8_mm_k`` for every K-group.
      * per-channel (``per_channel_quant``): one B scale per output column and
        one A scale per row are loaded here and applied after the K loop.
      * tensor-wise: one scalar A scale (1.0 when ``a_scale_ptr`` is None)
        and one scalar B scale per LoRA, applied after the K loop.

    The result is multiplied by ``scaling``, cast to the output element type
    and written with a plain store when ``SPLIT_K == 1``, or accumulated with
    a relaxed atomic add otherwise.

    ``launch_pdl`` is accepted but not read in this body.
    """
    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_lora_ptr = lora_ptr
        cur_b_scale_ptr = b_scale_ptr
    else:
        # With several slices, lora_ptr / b_scale_ptr are tables of raw
        # addresses; load this slice's entry and reinterpret it as a typed
        # pointer. Weights are treated as FP8 when scales were supplied,
        # otherwise they use the input's element type.
        cur_lora_ptr = (
            tl.load(lora_ptr + slice_id).to(tl.pointer_type(tl.float8e4nv))
            if b_scale_ptr is not None
            else tl.load(lora_ptr + slice_id).to(
                tl.pointer_type(input_ptr.dtype.element_ty)
            )
        )
        cur_b_scale_ptr = (
            tl.load(b_scale_ptr + slice_id).to(tl.pointer_type(tl.float32))
            if b_scale_ptr is not None
            else b_scale_ptr
        )

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Columns are wrapped modulo N so every load stays in bounds; columns
    # past N are discarded by the store mask below.
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (
        input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride
    )
    b_ptr = (
        cur_lora_ptr
        + lora_d0_stride * lora_index
        + rbn[None, :] * lora_d1_stride
        + offset_k[:, None] * lora_d2_stride
    )

    # Load scales for tensor-wise or per-channel quantization (outside the loop)
    # Block-wise scales are loaded inside fp8_mm_k
    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise: compute scale pointers for fp8_mm_k
            # a_scale: per-row base pointers, shape (BLOCK_M,)
            # Each pointer points to the start of that row's scale data
            mm_a_scale_ptr = a_scale_ptr + ram * a_scale_m_stride
            # b_scale: pre-compute N-dimension offset
            # We need to bake in the N-group offset since fp8_mm_k doesn't know pid_n
            # The group index is derived from the column wrapped modulo N,
            # exactly like the weight pointers above. The scale load in
            # fp8_mm_k is unmasked, so an unwrapped column of a tail tile
            # could otherwise index past this LoRA's scale rows.
            offs_ns = (offset_n % N) // group_n
            # Base pointer with lora offset + N-group offset baked in, shape (BLOCK_N,)
            mm_b_scale_ptr = (
                cur_b_scale_ptr
                + lora_index * b_scale_l_stride
                + offs_ns * b_scale_n_stride
            )
        elif per_channel_quant:
            # Per-channel for weights, per-token for activations
            b_scale_ptrs = (
                cur_b_scale_ptr + lora_index * b_scale_l_stride + rbn * b_scale_n_stride
            )
            b_scale = tl.load(b_scale_ptrs)
            # Per-token activation scale
            a_scale = tl.load(a_scale_ptr + ram * a_scale_m_stride)[:, None]
            # For non-block-wise, pass original pointers (not used in mm loop)
            mm_a_scale_ptr = a_scale_ptr
            mm_b_scale_ptr = cur_b_scale_ptr
        else:
            # Tensor-wise quantization
            a_scale = tl.load(a_scale_ptr) if a_scale_ptr is not None else 1.0
            b_scale = tl.load(cur_b_scale_ptr + lora_index * b_scale_l_stride)
            # For non-block-wise, pass original pointers (not used in mm loop)
            mm_a_scale_ptr = a_scale_ptr
            mm_b_scale_ptr = cur_b_scale_ptr
    else:
        # Non-quantized path
        mm_a_scale_ptr = a_scale_ptr
        mm_b_scale_ptr = cur_b_scale_ptr

    # Compute partial/complete block matrix product.
    # CAST_TYPE is passed as False: the A tile is never cast in shrink.
    accumulator = fp8_mm_k(
        a_ptr,
        b_ptr,
        mm_a_scale_ptr,
        mm_b_scale_ptr,
        input_d1_stride,
        lora_d2_stride,
        a_scale_k_stride,
        b_scale_k_stride,
        offset_k,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        group_k,
        group_n,
        use_fp8_w8a8,
        per_channel_quant,
        False,
        cur_lora_ptr.dtype.element_ty,
        USE_GDC,
        base_k=pid_sk * BLOCK_K,
    )

    # GDC launch dependents hints the runtime system to launch dependent kernels.
    if USE_GDC:
        tl.extra.cuda.gdc_launch_dependents()

    # Apply dequantization scales for tensor-wise/per-channel quantization
    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise: already applied in fp8_mm_k
            pass
        else:
            # Tensor-wise or per-channel: apply scales after accumulation
            accumulator = accumulator * a_scale * b_scale

    # Apply LoRA scaling factor
    accumulator *= scaling

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride
    c_ptr = (
        cur_out_ptr
        + ram[:, None] * output_d1_stride
        + offset_cn[None, :] * output_d2_stride
    )
    # Only the first M_LEN rows of the tile and the columns below N are real.
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)

    # Cast accumulator to output dtype
    accumulator = accumulator.to(out_ptr.dtype.element_ty)

    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        # Several programs contribute partial sums to the same output tile.
        tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed")
@triton.jit
def do_expand_kernel_fp8(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    a_scale_ptr,
    b_scale_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # scale strides
    a_scale_m_stride,
    a_scale_k_stride,
    b_scale_l_stride,
    b_scale_n_stride,
    b_scale_k_stride,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    USE_GDC: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
):
    """
    FP8-compatible expand kernel for LoRA.

    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice,
    compute the matrix product with FP8 quantization support and store in
    the appropriate output location.

    For expand kernel, the input (shrink output) may be in FP32/FP16/BF16,
    while the LoRA B weights can be in FP8.

    Supports:
    - FP8 W8A8 quantization for LoRA B weights
    - Block-wise quantization with configurable group_k and group_n
      (scales applied inside ``fp8_mm_k``)
    - Per-channel quantization (scales applied after the K loop; the
      activation scale defaults to 1.0 when ``a_scale_ptr`` is None)
    - Tensor-wise quantization (scales applied after the K loop; the
      activation scale defaults to 1.0 when ``a_scale_ptr`` is None)

    When ``ADD_INPUTS`` is set, the existing output tile is loaded and added
    to the result before it is stored.
    """
    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)

    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        # Single slice: the arguments already are the tensors themselves, for
        # both the quantized and the non-quantized path.
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
        cur_b_scale_ptr = b_scale_ptr  # May be None for non-quantized
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        # lora_ptr / b_scale_ptr are tables of raw addresses; load this
        # slice's entry and reinterpret it as a typed pointer.
        if use_fp8_w8a8:
            cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
                tl.pointer_type(tl.float8e4nv)
            )
            cur_b_scale_ptr = tl.load(b_scale_ptr + slice_id).to(
                tl.pointer_type(tl.float32)
            )
        else:
            cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
                tl.pointer_type(out_ptr.dtype.element_ty)
            )
            cur_b_scale_ptr = (
                tl.load(b_scale_ptr + slice_id).to(tl.pointer_type(tl.float32))
                if b_scale_ptr is not None
                else None
            )

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Columns are wrapped modulo N so every load stays in bounds; columns
    # past N are discarded by the store mask below.
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    a_ptr = (
        cur_input_ptr
        + ram[:, None] * input_d1_stride
        + offset_k[None, :] * input_d2_stride
    )
    b_ptr = (
        cur_lora_ptr
        + cur_lora_d0_stride * lora_index
        + offset_k[:, None] * cur_lora_d2_stride
        + rbn[None, :] * cur_lora_d1_stride
    )

    # Setup scale pointers for FP8/INT8 quantization
    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise quantization - compute scale pointers for fp8_mm_k
            # a_scale: per-row base pointers, shape (BLOCK_M,)
            mm_a_scale_ptr = a_scale_ptr + ram * a_scale_m_stride
            # b_scale: pre-compute N-dimension offset since fp8_mm_k doesn't know pid_n
            # The group index is derived from the column wrapped modulo N,
            # exactly like the weight pointers above. The scale load in
            # fp8_mm_k is unmasked, so an unwrapped column of a tail tile
            # could otherwise index past this LoRA's scale rows.
            offs_ns = (offset_n % N) // group_n
            # Base pointer with lora offset + N-group offset baked in, shape (BLOCK_N,)
            mm_b_scale_ptr = (
                cur_b_scale_ptr
                + lora_index * b_scale_l_stride
                + offs_ns * b_scale_n_stride
            )
        elif per_channel_quant:
            # Per-channel for weights, shape (BLOCK_N,)
            b_scale_ptrs = (
                cur_b_scale_ptr + lora_index * b_scale_l_stride + rbn * b_scale_n_stride
            )
            b_scale = tl.load(b_scale_ptrs)
            # Per-token activation scale, only if a_scale_ptr provided;
            # otherwise the activations are taken as unscaled, matching the
            # tensor-wise branch below.
            a_scale = (
                tl.load(a_scale_ptr + ram * a_scale_m_stride)[:, None]
                if a_scale_ptr is not None
                else 1.0
            )
            # For non-block-wise, pass original pointers (not used in mm loop)
            mm_a_scale_ptr = a_scale_ptr
            mm_b_scale_ptr = cur_b_scale_ptr
        else:
            # Tensor-wise quantization
            a_scale = tl.load(a_scale_ptr) if a_scale_ptr is not None else 1.0
            b_scale = tl.load(cur_b_scale_ptr + lora_index * b_scale_l_stride)
            # For non-block-wise, pass original pointers (not used in mm loop)
            mm_a_scale_ptr = a_scale_ptr
            mm_b_scale_ptr = cur_b_scale_ptr
    else:
        # Non-quantized path
        mm_a_scale_ptr = a_scale_ptr
        mm_b_scale_ptr = cur_b_scale_ptr

    # Compute the block matrix product using fp8_mm_k
    # Note: For expand kernel, SPLIT_K=1, so we pass 1 for SPLIT_K
    accumulator = fp8_mm_k(
        a_ptr,
        b_ptr,
        mm_a_scale_ptr,
        mm_b_scale_ptr,
        input_d2_stride,  # ak_stride
        cur_lora_d2_stride,  # bk_stride
        a_scale_k_stride,
        b_scale_k_stride,
        offset_k,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        1,  # SPLIT_K = 1 for expand kernel
        group_k,
        group_n,
        use_fp8_w8a8,
        per_channel_quant,
        CAST_TYPE,  # when set, fp8_mm_k casts the A tile to B's dtype
        cur_lora_ptr.dtype.element_ty,
        USE_GDC,
        base_k=0,
    )

    # Apply dequantization scales for non-block-wise quantization
    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            pass  # Already applied per block in fp8_mm_k
        else:
            # Tensor-wise or per-channel: apply scales after accumulation
            accumulator = accumulator * a_scale * b_scale

    tiled_c = accumulator.to(out_ptr.dtype.element_ty)

    # slice_start_loc is a plain value for one slice and a per-slice table
    # otherwise.
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (
        out_ptr
        + ram[:, None] * output_d0_stride
        + offset_cn[None, :] * output_d1_stride
    )
    # Only the first M_LEN rows and the columns inside this slice are real.
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N))

    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
vllm/lora/ops/triton_ops/lora_expand_fp8_op.py
0 → 100644
View file @
b3ce711b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import
torch
from
vllm.lora.ops.triton_ops.fp8_kernel_utils
import
do_expand_kernel_fp8
from
vllm.lora.ops.triton_ops.utils
import
(
_get_lora_b_ptr
,
get_lora_op_configs
,
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
_EXPAND_LORA_SCALE_PTR_DICT
:
dict
[
tuple
[
int
,
...],
torch
.
tensor
]
=
{}
def
_get_expand_lora_scale_ptr
(
lora_weights
:
list
[
torch
.
Tensor
],
device
:
torch
.
device
):
"""
`_EXPAND_LORA_SCALE_PTR_DICT` collects the required information during
`profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key
=
tuple
(
lora_weight
.
data_ptr
()
for
lora_weight
in
lora_weights
)
if
(
ptr_tensor
:
=
_EXPAND_LORA_SCALE_PTR_DICT
.
get
(
key
))
is
not
None
:
return
ptr_tensor
if
len
(
lora_weights
)
>
1
:
tensor_ptrs
=
[]
for
lora_weight
in
lora_weights
:
tensor_ptrs
.
append
(
lora_weight
.
data_ptr
())
ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
,
dtype
=
torch
.
uint64
)
else
:
# Single slice: return the actual tensor so the kernel can use it
# directly without pointer indirection (matches SLICE_NUM == 1 path).
ptr_tensor
=
lora_weights
[
0
]
_EXPAND_LORA_SCALE_PTR_DICT
[
key
]
=
ptr_tensor
return
_EXPAND_LORA_SCALE_PTR_DICT
.
get
(
key
)
@triton.jit
def _lora_expand_kernel_fp8(
    input_ptr,
    lora_ptr,
    out_ptr,
    a_scale_ptr,
    b_scale_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    slice_start_loc,
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    a_scale_m_stride,
    a_scale_k_stride,
    b_scale_l_stride,
    b_scale_n_stride,
    b_scale_k_stride,
    output_d0_stride,
    output_d1_stride,
    output_hs_ptr,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    USE_GDC: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """
    FP8-compatible expand kernel wrapper.

    Maps the launch grid onto work items and delegates the math to
    ``do_expand_kernel_fp8``:
      * axis 0 encodes the (M-tile, N-tile) pair,
      * axis 1 is the slice index,
      * axis 2 is the position in ``lora_ids``.

    A program exits early when its LoRA id is -1, when its M-tile starts
    past the number of tokens assigned to that LoRA, or when its N-tile
    starts past the slice's hidden size.

    ``launch_pdl`` is accepted but not read in this body.
    """
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # Decompose the flat axis-0 program id into M-tile and N-tile indices.
    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn % cta_m_num
    pid_n = (pid_mn // cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA: this LoRA has fewer tokens than this tile's start.
        return

    # When slices differ in stride, each slice's own hidden size is read from
    # output_hs_ptr; otherwise N applies to every slice.
    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
    if pid_n * BLOCK_N >= curr_N:
        # Early exit CTA: tile lies past this slice's columns.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (
        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
    )

    # Load all relevant row indices. The modulo wraps lanes beyond cta_m_len
    # back onto valid rows so the load stays in bounds; those lanes are
    # masked out when the result is stored.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_expand_kernel_fp8(
        pid_n,
        lora_id,
        slice_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        a_scale_ptr,
        b_scale_ptr,
        curr_N,
        K,
        cta_m_len,
        ram,
        slice_start_loc,
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,
        a_scale_m_stride,
        a_scale_k_stride,
        b_scale_l_stride,
        b_scale_n_stride,
        b_scale_k_stride,
        output_d0_stride,
        output_d1_stride,
        group_n,
        group_k,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SAME_STRIDE,
        SLICE_NUM,
        EVEN_K,
        CAST_TYPE,
        ADD_INPUTS,
        USE_GDC,
        use_fp8_w8a8,
        per_channel_quant,
    )
@
torch
.
inference_mode
()
def
_lora_expand_fp8
(
inputs
:
torch
.
Tensor
,
# shape [num_slices, num_tokens, lora_rank]
lora_b_weights
:
list
[
torch
.
Tensor
],
# FP8 [num_lora, hidden_size, lora_rank]
output_tensor
:
torch
.
Tensor
,
# shape [num_tokens, hidden_size * num_slices]
token_lora_mapping
:
torch
.
Tensor
,
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
num_tokens_per_lora
:
torch
.
Tensor
,
lora_token_start_loc
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
no_lora_flag_cpu
:
torch
.
Tensor
,
# shape [1]
num_active_loras
:
int
,
# number of active LoRAs (unused here, for API compat)
b_scale
:
list
[
torch
.
Tensor
],
# LoRA B weight scale per slice
a_scale
:
torch
.
Tensor
|
None
=
None
,
# Scale for shrink output (optional)
offset_start
:
int
=
0
,
add_inputs
:
bool
=
False
,
group_k
:
int
=
0
,
group_n
:
int
=
0
,
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
"""
FP8-compatible LoRA expand operation.
Args:
inputs: Input tensor from shrink operation [num_slices, num_tokens, lora_rank]
lora_b_weights: List of FP8 LoRA B weights per slice
output_tensor: Output tensor
a_scale: Optional scale for input (if input is quantized)
b_scale: Weight quantization scales per slice
token_lora_mapping: Token to LoRA ID mapping
token_indices_sorted_by_lora_ids: Sorted token indices
num_tokens_per_lora: Number of tokens per LoRA
lora_token_start_loc: Start location for each LoRA's tokens
lora_ids: LoRA IDs to process
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
group_k (int, optional): Block size for K in block-wise quantization.
group_n (int, optional): Block size for N in block-wise quantization.
use_fp8_w8a8 (bool, optional): Whether to use FP8 W8A8 quantization.
per_channel_quant (bool, optional): Whether to use per-channel quantization.
"""
assert
no_lora_flag_cpu
.
numel
()
==
1
if
no_lora_flag_cpu
.
item
():
# None of the inputs require LoRA.
return
if
use_fp8_w8a8
:
assert
inputs
.
dtype
in
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
,
]
for
weight
in
lora_b_weights
:
assert
weight
.
dtype
in
[
torch
.
float8_e5m2
,
torch
.
float8_e4m3fn
,
]
else
:
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
for
weight
in
lora_b_weights
:
assert
weight
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
inputs
.
size
(
0
)
==
len
(
lora_b_weights
)
assert
output_tensor
.
is_contiguous
()
# metadata sanity check.
M
=
inputs
.
size
(
1
)
assert
token_lora_mapping
.
size
(
0
)
==
M
assert
token_lora_mapping
.
size
(
0
)
==
token_indices_sorted_by_lora_ids
.
size
(
0
)
assert
lora_ids
.
size
(
0
)
==
num_tokens_per_lora
.
size
(
0
)
assert
lora_token_start_loc
.
size
(
0
)
==
lora_ids
.
size
(
0
)
+
1
(
slice_start_tensor
,
lora_ptr_tensor
,
lora_strides_d0_tensor
,
lora_strides_d1_tensor
,
lora_strides_d2_tensor
,
hidden_sizes_tensor
,
same_stride
,
MAX_N
,
)
=
_get_lora_b_ptr
(
lora_b_weights
,
offset_start
,
inputs
.
device
)
# Get scale pointers
if
b_scale
is
not
None
:
b_scale_ptr_tensor
=
_get_expand_lora_scale_ptr
(
b_scale
,
inputs
.
device
)
else
:
b_scale_ptr_tensor
=
None
K
=
lora_b_weights
[
0
].
shape
[
-
1
]
ADD_INPUTS
=
add_inputs
MAX_LORAS
=
lora_ids
.
size
(
0
)
CAST_TYPE
=
False
NUM_SLICES
=
len
(
lora_b_weights
)
# Triton kernel configs.
kernel_config
=
get_lora_op_configs
(
op_type
=
"expand"
,
max_loras
=
MAX_LORAS
,
batch
=
M
,
hidden_size
=
MAX_N
,
rank
=
K
,
num_slices
=
NUM_SLICES
,
add_inputs
=
add_inputs
,
)
BLOCK_M
=
kernel_config
[
"block_m"
]
BLOCK_N
=
kernel_config
[
"block_n"
]
BLOCK_K
=
kernel_config
[
"block_k"
]
NUM_WARPS
=
kernel_config
[
"num_warps"
]
NUM_CTAS
=
kernel_config
.
get
(
"num_ctas"
,
1
)
NUM_STAGES
=
kernel_config
[
"num_stages"
]
EVEN_K
=
K
%
BLOCK_K
==
0
grid
=
(
triton
.
cdiv
(
M
,
BLOCK_M
)
*
triton
.
cdiv
(
MAX_N
,
BLOCK_N
),
NUM_SLICES
,
num_active_loras
,
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
use_gdc
=
False
# supports_pdl(inputs.device)
# Get scale strides
if
a_scale
is
not
None
:
a_scale_m_stride
=
a_scale
.
stride
(
0
)
if
a_scale
.
dim
()
>
1
else
0
a_scale_k_stride
=
a_scale
.
stride
(
-
1
)
if
a_scale
.
dim
()
>
1
else
0
else
:
a_scale_m_stride
=
0
a_scale_k_stride
=
0
if
b_scale
is
not
None
and
b_scale
[
0
].
dim
()
>
0
:
b_scale_l_stride
=
b_scale
[
0
].
stride
(
0
)
if
b_scale
[
0
].
dim
()
>
0
else
0
b_scale_n_stride
=
(
b_scale
[
0
].
stride
(
-
2
)
if
b_scale
[
0
].
dim
()
>
2
else
(
b_scale
[
0
].
stride
(
-
1
)
if
b_scale
[
0
].
dim
()
>
1
else
1
)
)
b_scale_k_stride
=
b_scale
[
0
].
stride
(
-
1
)
if
b_scale
[
0
].
dim
()
>
2
else
0
else
:
b_scale_l_stride
=
1
b_scale_n_stride
=
0
b_scale_k_stride
=
0
_lora_expand_kernel_fp8
[
grid
](
inputs
,
lora_ptr_tensor
,
output_tensor
,
a_scale
,
b_scale_ptr_tensor
,
M
,
MAX_N
,
K
,
token_indices_sorted_by_lora_ids
,
num_tokens_per_lora
,
lora_token_start_loc
,
lora_ids
,
slice_start_tensor
,
inputs
.
stride
(
0
),
inputs
.
stride
(
1
),
inputs
.
stride
(
2
),
lora_strides_d0_tensor
,
lora_strides_d1_tensor
,
lora_strides_d2_tensor
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
hidden_sizes_tensor
,
group_n
,
group_k
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
ADD_INPUTS
,
CAST_TYPE
,
NUM_SLICES
,
same_stride
,
use_gdc
,
use_fp8_w8a8
=
use_fp8_w8a8
,
per_channel_quant
=
per_channel_quant
,
num_warps
=
NUM_WARPS
,
num_ctas
=
NUM_CTAS
,
num_stages
=
NUM_STAGES
,
launch_pdl
=
use_gdc
,
)
return
def
_lora_expand_fp8_fake
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
list
[
torch
.
Tensor
],
output_tensor
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
num_tokens_per_lora
:
torch
.
Tensor
,
lora_token_start_loc
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
no_lora_flag_cpu
:
torch
.
Tensor
,
num_active_loras
:
int
,
b_scale
:
list
[
torch
.
Tensor
],
a_scale
:
torch
.
Tensor
|
None
=
None
,
offset_start
:
int
=
0
,
add_inputs
:
bool
=
False
,
group_k
:
int
=
0
,
group_n
:
int
=
0
,
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
return
# Expose the expand op under the public name `lora_expand_fp8`.
# Preferred path: register `_lora_expand_fp8` as a custom op (declaring that
# it mutates `output_tensor`, with `_lora_expand_fp8_fake` as the fake
# implementation) and bind the public name to the registered op.
try:
    direct_register_custom_op(
        op_name="lora_expand_fp8",
        op_func=_lora_expand_fp8,
        mutates_args=["output_tensor"],
        fake_impl=_lora_expand_fp8_fake,
    )
    lora_expand_fp8 = torch.ops.vllm.lora_expand_fp8

except AttributeError:
    # Fallback: call the Python function directly.
    # NOTE(review): presumably this covers environments where the custom-op
    # machinery or the `torch.ops.vllm` namespace is unavailable - confirm.
    lora_expand_fp8 = _lora_expand_fp8
vllm/lora/ops/triton_ops/lora_shrink_fp8_op.py
0 → 100644
View file @
b3ce711b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import
torch
from
vllm.lora.ops.triton_ops.fp8_kernel_utils
import
do_shrink_kernel_fp8
from
vllm.lora.ops.triton_ops.utils
import
_get_lora_a_ptr
,
get_lora_op_configs
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
_SHRINK_LORA_SCALE_PTR_DICT
:
dict
[
tuple
[
int
,
...],
tuple
]
=
{}
def
_get_shrink_lora_scale_ptr
(
lora_scale_weights
:
list
[
torch
.
Tensor
],
device
:
torch
.
device
):
"""
`_SHRINK_LORA_SCALE_PTR_DICT` collects the required information during
`profile_run`. After this, it remains constant and subsequent usage is
through LUT.
Returns a tuple of (scale_ptr_tensor, l_stride, n_stride, k_stride).
Supports scale tensors of varying dimensionality:
- 1D: (lora_num,) — tensor-wise quantization
- 2D: (lora_num, N) — per-channel quantization
- 3D: (lora_num, N, K) — block-wise quantization
- 4D: (lora_num, 1, N, K) — block-wise with extra dim (squeezed to 3D)
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key
=
tuple
(
lora_weight
.
data_ptr
()
for
lora_weight
in
lora_scale_weights
)
if
values
:
=
_SHRINK_LORA_SCALE_PTR_DICT
.
get
(
key
):
return
values
tensor_ptrs
=
[]
scale_l_strides
=
[]
scale_n_strides
=
[]
scale_k_strides
=
[]
for
lora_scale_weight
in
lora_scale_weights
:
if
lora_scale_weight
.
ndim
==
4
:
# shape:(lora_num,1,size,rank)
assert
lora_scale_weight
.
size
(
1
)
==
1
lora_scale_weight
=
lora_scale_weight
.
squeeze
(
dim
=
1
)
assert
1
<=
lora_scale_weight
.
ndim
<=
3
assert
lora_scale_weight
.
is_contiguous
()
tensor_ptrs
.
append
(
lora_scale_weight
.
data_ptr
())
scale_l_strides
.
append
(
lora_scale_weight
.
stride
(
0
)
if
lora_scale_weight
.
ndim
>
0
else
0
)
scale_n_strides
.
append
(
lora_scale_weight
.
stride
(
-
2
)
if
lora_scale_weight
.
ndim
>
2
else
(
lora_scale_weight
.
stride
(
-
1
)
if
lora_scale_weight
.
ndim
>
1
else
1
)
)
scale_k_strides
.
append
(
lora_scale_weight
.
stride
(
-
1
)
if
lora_scale_weight
.
ndim
>
2
else
0
)
if
len
(
lora_scale_weights
)
>
1
:
scale_ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
,
dtype
=
torch
.
uint64
)
else
:
scale_ptr_tensor
=
lora_scale_weights
[
0
]
if
(
len
(
set
(
scale_l_strides
))
>
1
or
len
(
set
(
scale_n_strides
))
>
1
or
len
(
set
(
scale_k_strides
))
>
1
):
raise
ValueError
(
"All LoRA scale weights must have the same stride."
)
_SHRINK_LORA_SCALE_PTR_DICT
[
key
]
=
(
scale_ptr_tensor
,
scale_l_strides
[
0
],
scale_n_strides
[
0
],
scale_k_strides
[
0
],
)
return
_SHRINK_LORA_SCALE_PTR_DICT
.
get
(
key
)
@triton.jit
def _lora_shrink_kernel_fp8(
    input_ptr,
    lora_ptr,
    out_ptr,
    a_scale_ptr,
    b_scale_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    scaling,
    input_d0_stride,
    input_d1_stride,
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    a_scale_m_stride,
    a_scale_k_stride,
    b_scale_l_stride,
    b_scale_n_stride,
    b_scale_k_stride,
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    # Expected to always be False for the shrink kernel.
    USE_GDC: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """Per-program entry point of the FP8 LoRA shrink launch.

    Program-id layout:
      * axis 0 packs (split-k index, M tile, N tile) into one flat id,
      * axis 1 is the slice index,
      * axis 2 indexes into `lora_ids`.

    The program decodes its tile coordinates, exits early when it has no
    work (LoRA id of -1, or an M tile past the number of tokens mapped to
    this LoRA), gathers the token row indices for its M tile and forwards
    everything to `do_shrink_kernel_fp8`, which performs the actual
    computation (defined outside this kernel).
    """
    # Slice and LoRA slot come straight from grid axes 1 and 2.
    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    # Split the flat axis-0 id into the split-k index and the (m, n) tile id.
    flat_pid = tl.program_id(axis=0)
    pid_sk = flat_pid % SPLIT_K
    tile_pid = flat_pid // SPLIT_K

    # Grouped tile ordering: tiles are visited in groups of GROUP_SIZE_M
    # rows, column-major inside each group, for better cache reuse.
    n_tiles = tl.cdiv(N, BLOCK_N)
    m_tiles = tl.cdiv(M, BLOCK_M)
    tiles_per_group = GROUP_SIZE_M * n_tiles
    group_first_m = (tile_pid // tiles_per_group) * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M rows of tiles.
    rows_in_group = min(m_tiles - group_first_m, GROUP_SIZE_M)
    pid_in_group = tile_pid % tiles_per_group
    pid_m = group_first_m + (pid_in_group % rows_in_group)
    pid_n = pid_in_group // rows_in_group

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # No LoRA assigned to this slot: nothing to compute.
        return

    tokens_for_lora = tl.load(num_tokens_per_lora + lora_idx)
    tile_row_start = pid_m * BLOCK_M
    if tile_row_start >= tokens_for_lora:
        # This M tile lies entirely past the tokens of this LoRA.
        return

    # Number of valid rows handled by this program.
    cta_m_len = min(BLOCK_M, tokens_for_lora - tile_row_start)

    # Locate this tile's window inside the LoRA-sorted token index array.
    lora_first_token = tl.load(lora_token_start_loc + lora_idx)
    row_index_ptr = (
        token_indices_sorted_by_lora_ids + lora_first_token + tile_row_start
    )

    # Lanes beyond cta_m_len wrap around (modulo) so that every lane reads
    # an in-range entry of the index array.
    lane_offsets = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(row_index_ptr + lane_offsets)

    do_shrink_kernel_fp8(
        pid_n,
        pid_sk,
        slice_id,
        lora_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        a_scale_ptr,
        b_scale_ptr,
        N,
        K,
        cta_m_len,
        ram,  # row indices of the input this program operates on
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # scale strides
        a_scale_m_stride,
        a_scale_k_stride,
        b_scale_l_stride,
        b_scale_n_stride,
        b_scale_k_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        # block sizes for block-wise quantization
        group_n,
        group_k,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM,
        USE_GDC,
        use_fp8_w8a8,
        per_channel_quant,
        launch_pdl,
    )
@
torch
.
inference_mode
()
def
_lora_shrink_fp8
(
inputs
:
torch
.
Tensor
,
# shape [num_tokens, hidden_size] - FP8 or FP16/BF16
lora_a_weights
:
list
[
torch
.
Tensor
],
# shape [num_loras, lora_rank, hidden_size] - FP8 or FP16/BF16
output_tensor
:
torch
.
Tensor
,
# shape [num_slices, num_tokens, lora_rank]
token_lora_mapping
:
torch
.
Tensor
,
# shape [num_tokens]
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
# shape [num_tokens]
num_tokens_per_lora
:
torch
.
Tensor
,
# shape [max-loras + 1]
lora_token_start_loc
:
torch
.
Tensor
,
# shape [max-loras + 2]
lora_ids
:
torch
.
Tensor
,
# shape [max-loras + 1]
no_lora_flag_cpu
:
torch
.
Tensor
,
# shape [1]
num_active_loras
:
int
,
# number of active LoRAs (unused here, for API compat)
scaling
:
float
,
b_scale
:
list
[
torch
.
Tensor
],
# LoRA weight scale per slice
a_scale
:
torch
.
Tensor
|
None
=
None
,
# Activation scale - per-token or block-wise
group_k
:
int
=
0
,
# Block size for K in block-wise quantization (0 = tensor-wise)
group_n
:
int
=
0
,
# Block size for N in block-wise quantization
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
"""
Args:
inputs: FP8 or FP16/BF16 input tensor [num_tokens, hidden_size]
lora_a_weights: List of FP8 or FP16/BF16 LoRA A weights per slice
output_tensor: Output tensor (FP16/BF16/FP32)
token_lora_mapping: Token to LoRA ID mapping
token_indices_sorted_by_lora_ids: Sorted token indices
num_tokens_per_lora: Number of tokens per LoRA
lora_token_start_loc: Start location for each LoRA's tokens
lora_ids: LoRA IDs to process
scaling: LoRA scaling factor
a_scale: Activation quantization scales
b_scale: Weight quantization scales per slice
group_k: Block size for K dimension quantization
group_n: Block size for N dimension quantization
use_fp8_w8a8: Whether to use FP8 weights and activations
per_channel_quant: Whether to use per-channel quantization
"""
assert
no_lora_flag_cpu
.
numel
()
==
1
if
no_lora_flag_cpu
.
item
():
# None of the inputs require LoRA.
return
assert
inputs
.
size
(
1
)
==
lora_a_weights
[
0
].
size
(
-
1
)
assert
inputs
.
is_contiguous
()
assert
output_tensor
.
is_contiguous
()
# metadata sanity check
M
=
inputs
.
size
(
0
)
assert
token_lora_mapping
.
size
(
0
)
==
M
assert
token_lora_mapping
.
size
(
0
)
==
token_indices_sorted_by_lora_ids
.
size
(
0
)
assert
lora_ids
.
size
(
0
)
==
num_tokens_per_lora
.
size
(
0
)
assert
lora_token_start_loc
.
size
(
0
)
==
lora_ids
.
size
(
0
)
+
1
output_tensor
.
zero_
()
# Get LoRA weight pointers
(
lora_ptr_tensor
,
lora_strides_d0
,
lora_strides_d1
,
lora_strides_d2
)
=
(
_get_lora_a_ptr
(
lora_a_weights
,
inputs
.
device
)
)
# Get scale pointers if using FP8
if
use_fp8_w8a8
:
assert
a_scale
is
not
None
,
"a_scale required for FP8 w8a8"
assert
b_scale
is
not
None
,
"b_scale required for FP8"
b_scale_ptr_tensor
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
=
(
_get_shrink_lora_scale_ptr
(
b_scale
,
inputs
.
device
)
)
a_scale_ptr
=
(
a_scale
if
a_scale
is
not
None
else
torch
.
tensor
(
1.0
,
device
=
inputs
.
device
)
)
else
:
b_scale_ptr_tensor
=
torch
.
tensor
(
0
,
device
=
inputs
.
device
)
b_scale_l_stride
=
0
b_scale_n_stride
=
0
b_scale_k_stride
=
0
a_scale_ptr
=
torch
.
tensor
(
0
,
device
=
inputs
.
device
)
N
,
K
=
lora_a_weights
[
0
].
shape
[
-
2
:]
# K=hidden_size, N=rank
NUM_SLICES
=
len
(
lora_a_weights
)
MAX_LORAS
=
lora_ids
.
size
(
0
)
# Triton kernel configs
kernel_config
=
get_lora_op_configs
(
"shrink"
,
max_loras
=
MAX_LORAS
,
batch
=
M
,
hidden_size
=
K
,
rank
=
N
,
num_slices
=
NUM_SLICES
,
)
BLOCK_M
=
kernel_config
[
"block_m"
]
BLOCK_N
=
kernel_config
[
"block_n"
]
BLOCK_K
=
kernel_config
[
"block_k"
]
SPLIT_K
=
kernel_config
[
"split_k"
]
NUM_WARPS
=
kernel_config
[
"num_warps"
]
NUM_STAGES
=
kernel_config
[
"num_stages"
]
NUM_CTAS
=
kernel_config
[
"num_ctas"
]
GROUP_SIZE_M
=
kernel_config
.
get
(
"group_size_m"
,
8
)
assert
BLOCK_K
is
not
None
and
SPLIT_K
is
not
None
EVEN_K
=
K
%
(
BLOCK_K
*
SPLIT_K
)
==
0
# Grid configuration with column-major ordering support
grid
=
(
SPLIT_K
*
triton
.
cdiv
(
M
,
BLOCK_M
)
*
triton
.
cdiv
(
N
,
BLOCK_N
),
NUM_SLICES
,
num_active_loras
,
)
# Determine scale strides
if
use_fp8_w8a8
:
if
a_scale
is
not
None
and
a_scale
.
ndim
==
2
:
a_scale_m_stride
=
a_scale
.
stride
(
0
)
a_scale_k_stride
=
a_scale
.
stride
(
1
)
else
:
a_scale_m_stride
=
0
a_scale_k_stride
=
0
else
:
a_scale_m_stride
=
0
a_scale_k_stride
=
0
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
use_gdc
=
False
# supports_pdl(inputs.device)
_lora_shrink_kernel_fp8
[
grid
](
inputs
,
lora_ptr_tensor
,
output_tensor
,
a_scale_ptr
,
b_scale_ptr_tensor
,
M
,
N
,
K
,
token_indices_sorted_by_lora_ids
,
num_tokens_per_lora
,
lora_token_start_loc
,
lora_ids
,
scaling
,
inputs
.
stride
(
0
),
inputs
.
stride
(
1
),
lora_strides_d0
,
lora_strides_d1
,
lora_strides_d2
,
a_scale_m_stride
,
a_scale_k_stride
,
b_scale_l_stride
,
b_scale_n_stride
,
b_scale_k_stride
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
output_tensor
.
stride
(
2
),
group_n
,
group_k
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
GROUP_SIZE_M
,
NUM_SLICES
,
use_gdc
,
use_fp8_w8a8
,
per_channel_quant
,
use_gdc
,
num_warps
=
NUM_WARPS
,
num_ctas
=
NUM_CTAS
,
num_stages
=
NUM_STAGES
,
)
return
def
_lora_shrink_fp8_fake
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
list
[
torch
.
Tensor
],
output_tensor
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
token_indices_sorted_by_lora_ids
:
torch
.
Tensor
,
num_tokens_per_lora
:
torch
.
Tensor
,
lora_token_start_loc
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
no_lora_flag_cpu
:
torch
.
Tensor
,
num_active_loras
:
int
,
scaling
:
float
,
b_scale
:
list
[
torch
.
Tensor
],
# LoRA weight scale per slice
a_scale
:
torch
.
Tensor
|
None
=
None
,
# Activation scale - per-token or block-wise
group_k
:
int
=
0
,
# Block size for K in block-wise quantization (0 = tensor-wise)
group_n
:
int
=
0
,
# Block size for N in block-wise quantization
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
)
->
None
:
return
# Expose the FP8 shrink launcher under the public name `lora_shrink_fp8`.
# Preferred path: register `_lora_shrink_fp8` as a custom op (declaring that
# it mutates `output_tensor`, with `_lora_shrink_fp8_fake` as the fake
# implementation) and bind the name to the registered op.
try:
    direct_register_custom_op(
        op_name="lora_shrink_fp8",
        op_func=_lora_shrink_fp8,
        mutates_args=["output_tensor"],
        fake_impl=_lora_shrink_fp8_fake,
    )
    lora_shrink_fp8 = torch.ops.vllm.lora_shrink_fp8

except AttributeError:
    # Registration or the `torch.ops.vllm` lookup raised AttributeError:
    # fall back to calling the plain Python function directly.
    # NOTE(review): presumably guards environments where the custom-op
    # machinery is unavailable - confirm which case triggers this.
    lora_shrink_fp8 = _lora_shrink_fp8
vllm/lora/ops/triton_ops/utils.py
View file @
b3ce711b
...
...
@@ -252,7 +252,7 @@ def get_lora_op_configs(
default
=
{
"block_m"
:
64
,
"block_n"
:
64
if
num_slices
>
1
else
128
,
"block_k"
:
16
,
"block_k"
:
32
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
2
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment