Unverified commit b3ce711b, authored by yugong333, committed by GitHub
Browse files

Fp8 lora dense kernel (#35242)


Signed-off-by: Yu Gong <yu3.gong@gmail.com>
parent abf61aaa
This diff is collapsed.
...@@ -12,13 +12,17 @@ from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( ...@@ -12,13 +12,17 @@ from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
fused_moe_lora_expand, fused_moe_lora_expand,
fused_moe_lora_shrink, fused_moe_lora_shrink,
) )
from vllm.lora.ops.triton_ops.lora_expand_fp8_op import lora_expand_fp8
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_fp8_op import lora_shrink_fp8
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
__all__ = [ __all__ = [
"lora_expand", "lora_expand",
"lora_expand_fp8",
"lora_shrink", "lora_shrink",
"lora_shrink_fp8",
"LoRAKernelMeta", "LoRAKernelMeta",
"fused_moe_lora", "fused_moe_lora",
"fused_moe_lora_shrink", "fused_moe_lora_shrink",
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
from vllm.lora.ops.triton_ops.fp8_kernel_utils import do_expand_kernel_fp8
from vllm.lora.ops.triton_ops.utils import (
_get_lora_b_ptr,
get_lora_op_configs,
)
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
_EXPAND_LORA_SCALE_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
def _get_expand_lora_scale_ptr(lora_weights: list[torch.Tensor], device: torch.device):
"""
`_EXPAND_LORA_SCALE_PTR_DICT` collects the required information during
`profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if (ptr_tensor := _EXPAND_LORA_SCALE_PTR_DICT.get(key)) is not None:
return ptr_tensor
if len(lora_weights) > 1:
tensor_ptrs = []
for lora_weight in lora_weights:
tensor_ptrs.append(lora_weight.data_ptr())
ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
else:
# Single slice: return the actual tensor so the kernel can use it
# directly without pointer indirection (matches SLICE_NUM == 1 path).
ptr_tensor = lora_weights[0]
_EXPAND_LORA_SCALE_PTR_DICT[key] = ptr_tensor
return _EXPAND_LORA_SCALE_PTR_DICT.get(key)
@triton.jit
def _lora_expand_kernel_fp8(
    input_ptr,
    lora_ptr,
    out_ptr,
    a_scale_ptr,
    b_scale_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    slice_start_loc,
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    a_scale_m_stride,
    a_scale_k_stride,
    b_scale_l_stride,
    b_scale_n_stride,
    b_scale_k_stride,
    output_d0_stride,
    output_d1_stride,
    output_hs_ptr,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    USE_GDC: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """
    FP8-compatible expand kernel wrapper.

    Maps the launch grid onto (M-tile, N-tile, slice, LoRA index), exits early
    for programs that have no work, gathers the token rows owned by this
    program and then delegates the tile computation to `do_expand_kernel_fp8`.

    Grid axes as read below:
        axis 0: flattened (M-tile, N-tile) index
        axis 1: slice index
        axis 2: index into `lora_ids`

    `launch_pdl` is accepted in the signature but is not referenced in this
    body.
    """
    # Number of tiles along N and M for the given block sizes.
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # Split the flattened axis-0 program id into its M-tile and N-tile parts.
    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn % cta_m_num
    pid_n = (pid_mn // cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    # A stored id of -1 marks an entry with no LoRA to apply: nothing to do.
    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        return

    # Number of token rows assigned to this LoRA; skip tiles that start past it.
    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        return

    # With SAME_STRIDE every slice uses N; otherwise the per-slice size is read
    # from `output_hs_ptr`. Skip N-tiles that start past that size.
    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
    if pid_n * BLOCK_N >= curr_N:
        return

    # Number of rows this program actually processes (the last tile may be
    # partial).
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Start of this program's run inside the LoRA-sorted token index array.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (
        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
    )

    # The modulo wraps the BLOCK_M lane offsets into [0, cta_m_len), so every
    # lane loads an index that belongs to this program's rows.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_expand_kernel_fp8(
        pid_n,
        lora_id,
        slice_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        a_scale_ptr,
        b_scale_ptr,
        curr_N,
        K,
        cta_m_len,
        ram,
        slice_start_loc,
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,
        a_scale_m_stride,
        a_scale_k_stride,
        b_scale_l_stride,
        b_scale_n_stride,
        b_scale_k_stride,
        output_d0_stride,
        output_d1_stride,
        group_n,
        group_k,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SAME_STRIDE,
        SLICE_NUM,
        EVEN_K,
        CAST_TYPE,
        ADD_INPUTS,
        USE_GDC,
        use_fp8_w8a8,
        per_channel_quant,
    )
@torch.inference_mode()
def _lora_expand_fp8(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: list[torch.Tensor], # FP8 [num_lora, hidden_size, lora_rank]
output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices]
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
b_scale: list[torch.Tensor], # LoRA B weight scale per slice
a_scale: torch.Tensor | None = None, # Scale for shrink output (optional)
offset_start: int = 0,
add_inputs: bool = False,
group_k: int = 0,
group_n: int = 0,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
) -> None:
"""
FP8-compatible LoRA expand operation.
Args:
inputs: Input tensor from shrink operation [num_slices, num_tokens, lora_rank]
lora_b_weights: List of FP8 LoRA B weights per slice
output_tensor: Output tensor
a_scale: Optional scale for input (if input is quantized)
b_scale: Weight quantization scales per slice
token_lora_mapping: Token to LoRA ID mapping
token_indices_sorted_by_lora_ids: Sorted token indices
num_tokens_per_lora: Number of tokens per LoRA
lora_token_start_loc: Start location for each LoRA's tokens
lora_ids: LoRA IDs to process
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
group_k (int, optional): Block size for K in block-wise quantization.
group_n (int, optional): Block size for N in block-wise quantization.
use_fp8_w8a8 (bool, optional): Whether to use FP8 W8A8 quantization.
per_channel_quant (bool, optional): Whether to use per-channel quantization.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
if use_fp8_w8a8:
assert inputs.dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]
for weight in lora_b_weights:
assert weight.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
]
else:
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == len(lora_b_weights)
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)
# Get scale pointers
if b_scale is not None:
b_scale_ptr_tensor = _get_expand_lora_scale_ptr(b_scale, inputs.device)
else:
b_scale_ptr_tensor = None
K = lora_b_weights[0].shape[-1]
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
NUM_SLICES = len(lora_b_weights)
# Triton kernel configs.
kernel_config = get_lora_op_configs(
op_type="expand",
max_loras=MAX_LORAS,
batch=M,
hidden_size=MAX_N,
rank=K,
num_slices=NUM_SLICES,
add_inputs=add_inputs,
)
BLOCK_M = kernel_config["block_m"]
BLOCK_N = kernel_config["block_n"]
BLOCK_K = kernel_config["block_k"]
NUM_WARPS = kernel_config["num_warps"]
NUM_CTAS = kernel_config.get("num_ctas", 1)
NUM_STAGES = kernel_config["num_stages"]
EVEN_K = K % BLOCK_K == 0
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
num_active_loras,
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
use_gdc = False # supports_pdl(inputs.device)
# Get scale strides
if a_scale is not None:
a_scale_m_stride = a_scale.stride(0) if a_scale.dim() > 1 else 0
a_scale_k_stride = a_scale.stride(-1) if a_scale.dim() > 1 else 0
else:
a_scale_m_stride = 0
a_scale_k_stride = 0
if b_scale is not None and b_scale[0].dim() > 0:
b_scale_l_stride = b_scale[0].stride(0) if b_scale[0].dim() > 0 else 0
b_scale_n_stride = (
b_scale[0].stride(-2)
if b_scale[0].dim() > 2
else (b_scale[0].stride(-1) if b_scale[0].dim() > 1 else 1)
)
b_scale_k_stride = b_scale[0].stride(-1) if b_scale[0].dim() > 2 else 0
else:
b_scale_l_stride = 1
b_scale_n_stride = 0
b_scale_k_stride = 0
_lora_expand_kernel_fp8[grid](
inputs,
lora_ptr_tensor,
output_tensor,
a_scale,
b_scale_ptr_tensor,
M,
MAX_N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
a_scale_m_stride,
a_scale_k_stride,
b_scale_l_stride,
b_scale_n_stride,
b_scale_k_stride,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
group_n,
group_k,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
NUM_SLICES,
same_stride,
use_gdc,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
launch_pdl=use_gdc,
)
return
def _lora_expand_fp8_fake(
inputs: torch.Tensor,
lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
b_scale: list[torch.Tensor],
a_scale: torch.Tensor | None = None,
offset_start: int = 0,
add_inputs: bool = False,
group_k: int = 0,
group_n: int = 0,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
) -> None:
return
# Expose the expand op under the public name `lora_expand_fp8`.
# Preferred path: register `_lora_expand_fp8` as a custom op (declaring that it
# mutates `output_tensor`, with `_lora_expand_fp8_fake` as its fake
# implementation) and bind the name to the registered `torch.ops.vllm` entry.
try:
    direct_register_custom_op(
        op_name="lora_expand_fp8",
        op_func=_lora_expand_fp8,
        mutates_args=["output_tensor"],
        fake_impl=_lora_expand_fp8_fake,
    )
    lora_expand_fp8 = torch.ops.vllm.lora_expand_fp8
except AttributeError:
    # Fallback: if registration or the `torch.ops.vllm` lookup raises
    # AttributeError, call the plain Python function directly.
    # NOTE(review): presumably this covers builds where custom-op registration
    # is unavailable — confirm which environments hit this path.
    lora_expand_fp8 = _lora_expand_fp8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
from vllm.lora.ops.triton_ops.fp8_kernel_utils import do_shrink_kernel_fp8
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
_SHRINK_LORA_SCALE_PTR_DICT: dict[tuple[int, ...], tuple] = {}
def _get_shrink_lora_scale_ptr(
lora_scale_weights: list[torch.Tensor], device: torch.device
):
"""
`_SHRINK_LORA_SCALE_PTR_DICT` collects the required information during
`profile_run`. After this, it remains constant and subsequent usage is
through LUT.
Returns a tuple of (scale_ptr_tensor, l_stride, n_stride, k_stride).
Supports scale tensors of varying dimensionality:
- 1D: (lora_num,) — tensor-wise quantization
- 2D: (lora_num, N) — per-channel quantization
- 3D: (lora_num, N, K) — block-wise quantization
- 4D: (lora_num, 1, N, K) — block-wise with extra dim (squeezed to 3D)
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_scale_weights)
if values := _SHRINK_LORA_SCALE_PTR_DICT.get(key):
return values
tensor_ptrs = []
scale_l_strides = []
scale_n_strides = []
scale_k_strides = []
for lora_scale_weight in lora_scale_weights:
if lora_scale_weight.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_scale_weight.size(1) == 1
lora_scale_weight = lora_scale_weight.squeeze(dim=1)
assert 1 <= lora_scale_weight.ndim <= 3
assert lora_scale_weight.is_contiguous()
tensor_ptrs.append(lora_scale_weight.data_ptr())
scale_l_strides.append(
lora_scale_weight.stride(0) if lora_scale_weight.ndim > 0 else 0
)
scale_n_strides.append(
lora_scale_weight.stride(-2)
if lora_scale_weight.ndim > 2
else (lora_scale_weight.stride(-1) if lora_scale_weight.ndim > 1 else 1)
)
scale_k_strides.append(
lora_scale_weight.stride(-1) if lora_scale_weight.ndim > 2 else 0
)
if len(lora_scale_weights) > 1:
scale_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
else:
scale_ptr_tensor = lora_scale_weights[0]
if (
len(set(scale_l_strides)) > 1
or len(set(scale_n_strides)) > 1
or len(set(scale_k_strides)) > 1
):
raise ValueError("All LoRA scale weights must have the same stride.")
_SHRINK_LORA_SCALE_PTR_DICT[key] = (
scale_ptr_tensor,
scale_l_strides[0],
scale_n_strides[0],
scale_k_strides[0],
)
return _SHRINK_LORA_SCALE_PTR_DICT.get(key)
@triton.jit
def _lora_shrink_kernel_fp8(
    input_ptr,
    lora_ptr,
    out_ptr,
    a_scale_ptr,
    b_scale_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    scaling,
    input_d0_stride,
    input_d1_stride,
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    a_scale_m_stride,
    a_scale_k_stride,
    b_scale_l_stride,
    b_scale_n_stride,
    b_scale_k_stride,
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    USE_GDC: tl.constexpr,  # should always be false in shrink kernel
    use_fp8_w8a8: tl.constexpr,
    per_channel_quant: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """
    FP8-compatible shrink kernel wrapper.

    Maps the launch grid onto (split-K part, M-tile, N-tile, slice, LoRA
    index), exits early for programs that have no work, gathers the token rows
    owned by this program and then delegates the tile computation to
    `do_shrink_kernel_fp8`.

    Grid axes as read below:
        axis 0: flattened (split-K part, M-tile, N-tile) index
        axis 1: slice index
        axis 2: index into `lora_ids`
    """
    # Number of tiles along N and M for the given block sizes.
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # The lowest part of the axis-0 id selects the split-K part; the rest is
    # the flattened (M-tile, N-tile) id.
    pid_sk_m_n = tl.program_id(axis=0)
    pid_sk = pid_sk_m_n % SPLIT_K

    pid_m_n = pid_sk_m_n // SPLIT_K
    # Tiles are visited in groups of up to GROUP_SIZE_M M-tiles; the last group
    # may hold fewer, hence the min().
    num_pid_in_group = GROUP_SIZE_M * cta_n_num
    group_id = pid_m_n // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)

    # Column-major ordering within groups for better cache reuse
    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
    pid_n = (pid_m_n % num_pid_in_group) // group_size_m

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (
        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
    )

    # Load all relevant row indices. The modulo wraps the BLOCK_M lane offsets
    # into [0, cta_m_len), so every lane loads an index owned by this CTA.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_shrink_kernel_fp8(
        pid_n,
        pid_sk,
        slice_id,
        lora_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        a_scale_ptr,
        b_scale_ptr,
        N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # scale strides
        a_scale_m_stride,
        a_scale_k_stride,
        b_scale_l_stride,
        b_scale_n_stride,
        b_scale_k_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        # block size for block-wise quantization
        group_n,
        group_k,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM,
        USE_GDC,
        use_fp8_w8a8,
        per_channel_quant,
        launch_pdl,
    )
@torch.inference_mode()
def _lora_shrink_fp8(
inputs: torch.Tensor, # shape [num_tokens, hidden_size] - FP8 or FP16/BF16
lora_a_weights: list[
torch.Tensor
], # shape [num_loras, lora_rank, hidden_size] - FP8 or FP16/BF16
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
scaling: float,
b_scale: list[torch.Tensor], # LoRA weight scale per slice
a_scale: torch.Tensor | None = None, # Activation scale - per-token or block-wise
group_k: int = 0, # Block size for K in block-wise quantization (0 = tensor-wise)
group_n: int = 0, # Block size for N in block-wise quantization
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
) -> None:
"""
Args:
inputs: FP8 or FP16/BF16 input tensor [num_tokens, hidden_size]
lora_a_weights: List of FP8 or FP16/BF16 LoRA A weights per slice
output_tensor: Output tensor (FP16/BF16/FP32)
token_lora_mapping: Token to LoRA ID mapping
token_indices_sorted_by_lora_ids: Sorted token indices
num_tokens_per_lora: Number of tokens per LoRA
lora_token_start_loc: Start location for each LoRA's tokens
lora_ids: LoRA IDs to process
scaling: LoRA scaling factor
a_scale: Activation quantization scales
b_scale: Weight quantization scales per slice
group_k: Block size for K dimension quantization
group_n: Block size for N dimension quantization
use_fp8_w8a8: Whether to use FP8 weights and activations
per_channel_quant: Whether to use per-channel quantization
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
output_tensor.zero_()
# Get LoRA weight pointers
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
_get_lora_a_ptr(lora_a_weights, inputs.device)
)
# Get scale pointers if using FP8
if use_fp8_w8a8:
assert a_scale is not None, "a_scale required for FP8 w8a8"
assert b_scale is not None, "b_scale required for FP8"
b_scale_ptr_tensor, b_scale_l_stride, b_scale_n_stride, b_scale_k_stride = (
_get_shrink_lora_scale_ptr(b_scale, inputs.device)
)
a_scale_ptr = (
a_scale if a_scale is not None else torch.tensor(1.0, device=inputs.device)
)
else:
b_scale_ptr_tensor = torch.tensor(0, device=inputs.device)
b_scale_l_stride = 0
b_scale_n_stride = 0
b_scale_k_stride = 0
a_scale_ptr = torch.tensor(0, device=inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size, N=rank
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
# Triton kernel configs
kernel_config = get_lora_op_configs(
"shrink",
max_loras=MAX_LORAS,
batch=M,
hidden_size=K,
rank=N,
num_slices=NUM_SLICES,
)
BLOCK_M = kernel_config["block_m"]
BLOCK_N = kernel_config["block_n"]
BLOCK_K = kernel_config["block_k"]
SPLIT_K = kernel_config["split_k"]
NUM_WARPS = kernel_config["num_warps"]
NUM_STAGES = kernel_config["num_stages"]
NUM_CTAS = kernel_config["num_ctas"]
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
assert BLOCK_K is not None and SPLIT_K is not None
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
# Grid configuration with column-major ordering support
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
num_active_loras,
)
# Determine scale strides
if use_fp8_w8a8:
if a_scale is not None and a_scale.ndim == 2:
a_scale_m_stride = a_scale.stride(0)
a_scale_k_stride = a_scale.stride(1)
else:
a_scale_m_stride = 0
a_scale_k_stride = 0
else:
a_scale_m_stride = 0
a_scale_k_stride = 0
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
use_gdc = False # supports_pdl(inputs.device)
_lora_shrink_kernel_fp8[grid](
inputs,
lora_ptr_tensor,
output_tensor,
a_scale_ptr,
b_scale_ptr_tensor,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
a_scale_m_stride,
a_scale_k_stride,
b_scale_l_stride,
b_scale_n_stride,
b_scale_k_stride,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
group_n,
group_k,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
GROUP_SIZE_M,
NUM_SLICES,
use_gdc,
use_fp8_w8a8,
per_channel_quant,
use_gdc,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
)
return
def _lora_shrink_fp8_fake(
inputs: torch.Tensor,
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
scaling: float,
b_scale: list[torch.Tensor], # LoRA weight scale per slice
a_scale: torch.Tensor | None = None, # Activation scale - per-token or block-wise
group_k: int = 0, # Block size for K in block-wise quantization (0 = tensor-wise)
group_n: int = 0, # Block size for N in block-wise quantization
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
) -> None:
return
# Expose the shrink op under the public name `lora_shrink_fp8`.
# Preferred path: register `_lora_shrink_fp8` as a custom op (declaring that it
# mutates `output_tensor`, with `_lora_shrink_fp8_fake` as its fake
# implementation) and bind the name to the registered `torch.ops.vllm` entry.
try:
    direct_register_custom_op(
        op_name="lora_shrink_fp8",
        op_func=_lora_shrink_fp8,
        mutates_args=["output_tensor"],
        fake_impl=_lora_shrink_fp8_fake,
    )
    lora_shrink_fp8 = torch.ops.vllm.lora_shrink_fp8
except AttributeError:
    # Fallback: if registration or the `torch.ops.vllm` lookup raises
    # AttributeError, call the plain Python function directly.
    # NOTE(review): presumably this covers builds where custom-op registration
    # is unavailable — confirm which environments hit this path.
    lora_shrink_fp8 = _lora_shrink_fp8
...@@ -252,7 +252,7 @@ def get_lora_op_configs( ...@@ -252,7 +252,7 @@ def get_lora_op_configs(
default = { default = {
"block_m": 64, "block_m": 64,
"block_n": 64 if num_slices > 1 else 128, "block_n": 64 if num_slices > 1 else 128,
"block_k": 16, "block_k": 32,
"num_warps": 4, "num_warps": 4,
"num_ctas": 1, "num_ctas": 1,
"num_stages": 2, "num_stages": 2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment