"vscode:/vscode.git/clone" did not exist on "2bf2566fed64ffff6ff878f252a5b29bfccbfced"
Unverified Commit ea4bf122 authored by Lifu Huang, committed by GitHub

Fix division-by-zero bug in LoRA triton kernels. (#7785)

parent a291439a
@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
# For fused output scaling
scalings,
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, 2 * output_dim, K)
# output: (s, 2 * output_dim)
"""
This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
results are accumulated into the output tensor.
When a sequence's rank is 0, the kernel is essentially a no-op, following
the convention in pytorch where the product of two matrices of shape (m, 0)
and (0, n) is an all-zero matrix of shape (m, n).
Args:
x (Tensor): The input tensor, which is the result of the LoRA A projection.
Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
batch and K is the maximum LoRA rank.
weights (Tensor): The LoRA B weights for all adapters.
Shape: (num_lora, 2 * output_dim, K).
output (Tensor): The output tensor where the result is stored.
Shape: (s, 2 * output_dim).
"""
# output_dim >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# gate_up_id decides which of gate or up (0: gate, 1: up)
batch_id = tl.program_id(axis=2)
w_index = tl.load(weight_indices + batch_id)
rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel is a no-op.
if rank == 0:
return
gate_up_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = gate_up_id * output_dim # offset on output dim
rank = tl.load(lora_ranks + w_index)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K)
and (n_offset[None, :] < output_dim),
& (n_offset[None, :] < output_dim),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
)
if base_output is None:
output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
else:
output = base_output
fuse_scaling_add = True
_gate_up_lora_b_kernel[grid_b](
x,
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
batch_info.scalings,
)
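The new docstring for _gate_up_lora_b_kernel justifies the early return for rank-0 adapters by PyTorch's convention that an (m, 0) matrix times a (0, n) matrix is an all-zero (m, n) matrix. A minimal standalone PyTorch snippet (not part of this commit) demonstrating that convention, and hence why skipping the kernel over a zero-initialized output is equivalent to applying a rank-0 adapter:

import torch

# Rank-0 adapter: both LoRA factors have a zero-sized inner dimension.
m, n = 4, 6
lora_a_out = torch.randn(m, 0)      # (seq_len, rank) with rank == 0
lora_b_weight = torch.randn(0, n)   # (rank, output_dim) with rank == 0

delta = lora_a_out @ lora_b_weight  # PyTorch defines this as zeros of shape (m, n)
assert delta.shape == (m, n)
assert torch.count_nonzero(delta) == 0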
@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
# For fused output scaling
scalings,
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, N_Q + 2 * N_KV, K)
# output: (s, N_Q + 2 * N_KV)
# N_Q >> K, N_KV >> K
"""
This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
results are accumulated into the output tensor.
When a sequence's rank is 0, the kernel is essentially a no-op, following
the convention in pytorch where the product of two matrices of shape (m, 0)
and (0, n) is an all-zero matrix of shape (m, n).
Args:
x (Tensor): The input tensor, which is the result of the LoRA A projection.
Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
batch and K is the maximum LoRA rank. The second dimension is partitioned
for Q, K, and V.
weights (Tensor): The LoRA B weights for all adapters.
Shape: (num_lora, N_Q + 2 * N_KV, K).
output (Tensor): The output tensor where the result is stored.
Shape: (s, N_Q + 2 * N_KV).
"""
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
batch_id = tl.program_id(axis=2)
w_index = tl.load(weight_indices + batch_id)
rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel is a no-op.
if rank == 0:
return
qkv_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = tl.load(n_offs + qkv_id)
n_size = tl.load(n_offs + qkv_id + 1) - n_start
rank = tl.load(lora_ranks + w_index)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
K = tl.minimum(K, rank)
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
)
if base_output is None:
output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
else:
output = base_output
fuse_scaling_add = True
_qkv_lora_b_kernel[grid_b](
x,
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
batch_info.scalings,
)
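The hunks above also replace 'and' with '&' when building load/store masks. As motivation (illustrative only, shown in plain PyTorch rather than Triton, whose in-kernel semantics are its own matter): 'and' asks a tensor for a single truth value, while '&' combines boolean tensors elementwise, which is what these 2-D masks need.

import torch

s_offset = torch.arange(4)[:, None]     # (4, 1) row indices
n_offset = torch.arange(6)[None, :]     # (1, 6) column indices

mask = (s_offset < 3) & (n_offset < 5)  # elementwise AND, broadcast to (4, 6)
assert mask.shape == (4, 6)

try:
    (s_offset < 3) and (n_offset < 5)   # wants one truth value per tensor
except RuntimeError as err:
    print("'and' on a multi-element tensor raises:", err)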
@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
"""
Computes a segmented batched matrix multiplication for the LoRA A matrix.
The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
stores the product of the input `x` and the LoRA weights for the corresponding
sequence. This implies that when rank is 0, the kernel is essentially a no-op,
as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
Args:
x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
is the sum of all sequence lengths in the batch.
weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
with shape `(num_lora, N, K)`.
output (torch.Tensor): The output tensor of shape `(s, N)`.
"""
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
if rank == 0:
return
pid = tl.program_id(axis=0)
seg_start = tl.load(seg_indptr + batch_id)
seg_len = tl.load(seg_lens + batch_id)
# Adjust N (stack_num * max_rank) according to the specific LoRA adapter
N = tl.minimum(N, rank * stack_num)
@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
tl.store(output_ptr, partial_sum, mask=output_mask)
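As a reading aid for the _sgemm_lora_a_kernel docstring, here is a hedged pure-PyTorch reference of the segmented matmul it describes. It assumes seg_indptr is a CSR-style prefix array (seg_indptr[b + 1] - seg_indptr[b] == seg_lens[b]); the function name and loop structure are illustrative, not the kernel's actual implementation.

import torch

def sgemm_lora_a_reference(x, weights, seg_indptr, weight_indices, lora_ranks, stack_num):
    # x: (s, K), weights: (num_lora, N, K) -> output: (s, N)
    s = x.shape[0]
    N = weights.shape[1]
    output = torch.zeros(s, N, dtype=x.dtype, device=x.device)
    for b in range(weight_indices.numel()):
        seg_start, seg_end = int(seg_indptr[b]), int(seg_indptr[b + 1])
        w_index = int(weight_indices[b])
        n = int(lora_ranks[w_index]) * stack_num   # only the first rank * stack_num columns are meaningful
        if n == 0:
            continue  # rank-0 adapter: output[..., :0] is trivially correct
        output[seg_start:seg_end, :n] = x[seg_start:seg_end] @ weights[w_index, :n, :].T
    return output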
@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
# For fused output scaling
scalings,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
"""
Computes a segmented batched matrix multiplication for the LoRA B matrix
and adds the result to the output in-place.
When a sequence's rank is 0, the kernel is essentially a no-op, following
the convention in pytorch where the product of two matrices of shape (m, 0)
and (0, n) is an all-zero matrix of shape (m, n).
Args:
x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
of shape `(s, K)`, where `s` is the total number of tokens.
weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
with shape `(num_lora, N, K)`.
output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
the base model's output for a fused add operation.
"""
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
w_index = tl.load(weight_indices + batch_id)
rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel is a no-op.
if rank == 0:
return
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
rank = tl.load(lora_ranks + w_index)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
K = tl.minimum(K, rank)
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = s_offset[:, None] < seg_len
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
)
if base_output is None:
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
else:
output = base_output
fuse_scaling_add = True
_sgemm_lora_b_kernel[grid](
x,
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
BLOCK_S,
BLOCK_N,
BLOCK_R,
fuse_scaling_add,
batch_info.scalings,
)
return output
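A pattern shared by all three launchers above: the fuse_scaling_add flag is gone, the kernels now always load-add-store into output, and the launchers therefore switch from torch.empty to torch.zeros when there is no base_output. A minimal PyTorch sketch (illustrative, outside the commit) of why that zero-initialization matters:

import torch

s, N = 8, 16
lora_delta = torch.randn(s, N)      # stands in for scaling * (x @ B.T)
base_output = None                  # no base activation to fuse with

# Always accumulate, so start from zeros when there is no base output.
output = torch.zeros(s, N) if base_output is None else base_output
output += lora_delta                # what the kernel's load-add-store does

assert torch.equal(output, lora_delta)  # identical to a plain store
# Starting from torch.empty instead would add lora_delta to uninitialized
# memory, so the unconditional accumulation would produce garbage.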