change / sglang / Commits / ea4bf122

Unverified commit ea4bf122, authored Jul 06, 2025 by Lifu Huang, committed via GitHub on Jul 06, 2025

Fix division-by-zero bug in LoRA triton kernels. (#7785)

parent a291439a
Showing 4 changed files with 114 additions and 64 deletions (+114 -64)
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py   +30 -19
python/sglang/srt/lora/triton_ops/qkv_lora_b.py       +30 -19
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py     +27 -11
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py     +27 -15
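The new kernel docstrings below rely on the PyTorch convention that the product of a (m, 0) matrix and a (0, n) matrix is an all-zero (m, n) matrix, which is why a rank-0 adapter can be treated as a no-op. A minimal illustration (shapes chosen arbitrarily for demonstration, not taken from the diff):

import torch

# A rank-0 LoRA adapter corresponds to A: (m, 0) and B: (0, n).
m, n = 4, 8
a = torch.empty(m, 0)
b = torch.empty(0, n)
product = a @ b  # shape (m, n), all zeros
assert torch.equal(product, torch.zeros(m, n))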
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py

@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling and adding
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-    # This kernel packs 2 sgemms (gate/up) into a single kernel.
-    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
-    # weights: (num_lora, 2 * output_dim, K)
-    # output: (s, 2 * output_dim)
+    """
+    This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, 2 * output_dim, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, 2 * output_dim).
+    """
     # output_dim >> K

     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # gate_up_id decides which of gate or up (0: gate, 1: up)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     gate_up_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)

     # Adjust K (rank) according to the specific LoRA adapter

@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
             mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
             mask=(k_offset[:, None] < K - k * BLOCK_K)
-            and (n_offset[None, :] < output_dim),
+            & (n_offset[None, :] < output_dim),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)

@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
-    if fuse_scaling_add:
-        partial_sum += tl.load(output_ptr, mask=output_mask)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)

@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
     )

     if base_output is None:
-        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
-        fuse_scaling_add = False
+        output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True

     _gate_up_lora_b_kernel[grid_b](
         x,

@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
...
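Since a rank-0 segment now exits the kernel before any tl.store, the wrapper can no longer hand the kernel an uninitialized buffer, which is why torch.empty plus the fuse_scaling_add flag is replaced by torch.zeros above. A rough host-side sketch of that logic (hypothetical helper name, heavily simplified; the real gate_up_lora_b_fwd passes many more arguments):

import torch

def allocate_lora_output(base_output, s, output_dim, x):
    # Rows belonging to rank-0 adapters are never written by the kernel,
    # so a fresh buffer must start as zeros rather than uninitialized memory.
    # The kernel then always accumulates into `output`, whether it is this
    # zero buffer or a caller-provided base_output.
    if base_output is None:
        return torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
    return base_output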
python/sglang/srt/lora/triton_ops/qkv_lora_b.py

@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling and adding
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-    # This kernel packs 3 sgemms (q/k/v) into a single kernel.
-    # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
-    # weights: (num_lora, N_Q + 2 * N_KV, K)
-    # output: (s, N_Q + 2 * N_KV)
-    # N_Q >> K, N_KV >> K
+    """
+    This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank. The second dimension is partitioned
+            for Q, K, and V.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, N_Q + 2 * N_KV, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, N_Q + 2 * N_KV).
+    """
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     qkv_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)

     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)

@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
             mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K)
-            and (n_offset[None, :] < n_size),
+            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            & (n_offset[None, :] < n_size),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)

@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
-    if fuse_scaling_add:
-        partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)

@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
     )

     if base_output is None:
-        output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
-        fuse_scaling_add = False
+        output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True

     _qkv_lora_b_kernel[grid_b](
         x,

@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
...
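For context, the qkv kernel's n_start and n_size come from an offsets tensor n_offs that partitions the fused output dimension (N_Q + 2 * N_KV) into the q, k, and v column ranges. An illustrative sketch with made-up projection widths (the arithmetic mirrors the two tl.load lines above):

import torch

N_Q, N_KV = 4096, 1024  # hypothetical widths, not taken from the diff
n_offs = torch.tensor([0, N_Q, N_Q + N_KV, N_Q + 2 * N_KV])

for qkv_id, name in enumerate(("q", "k", "v")):
    n_start = int(n_offs[qkv_id])
    n_size = int(n_offs[qkv_id + 1]) - n_start
    print(f"{name}: output columns [{n_start}, {n_start + n_size})")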
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py

@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
-    # x: (s, K), s is the sum of sequence lengths
-    # weights: (num_lora, N, K)
-    # output: (s, N)
+    """
+    Computes a segmented batched matrix multiplication for the LoRA A matrix.
+
+    The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
+    stores the product of the input `x` and the LoRA weights for the corresponding
+    sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+    as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+    Args:
+        x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+            is the sum of all sequence lengths in the batch.
+        weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`.
+    """
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
-    pid = tl.program_id(axis=0)
-    seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
-    seg_start = tl.load(seg_indptr + batch_id)
     rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+    if rank == 0:
+        return
+
+    pid = tl.program_id(axis=0)
+    seg_start = tl.load(seg_indptr + batch_id)
+    seg_len = tl.load(seg_lens + batch_id)

     # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
     N = tl.minimum(N, rank * stack_num)

@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
             mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K)
-            and (n_offset[None, :] < N),
+            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            & (n_offset[None, :] < N),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)

@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
     output_ptr = (output + seg_start * output_stride_0) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
     tl.store(output_ptr, partial_sum, mask=output_mask)
...
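All four kernels share the same segment bookkeeping: seg_lens holds per-sequence token counts, seg_indptr the corresponding row offsets into x (a prefix sum in this sketch), weight_indices maps each sequence to an adapter slot, and lora_ranks[slot] can be 0 when a request carries no adapter. A toy example of how the early return plays out (values invented for illustration):

import torch

seg_lens = torch.tensor([3, 5, 2])                  # tokens per sequence
seg_indptr = torch.cat([torch.zeros(1, dtype=torch.long), seg_lens.cumsum(0)])
weight_indices = torch.tensor([0, 1, 0])            # adapter slot per sequence
lora_ranks = torch.tensor([16, 0])                  # slot 1 has rank 0 -> no-op

for batch_id in range(len(seg_lens)):
    rank = int(lora_ranks[weight_indices[batch_id]])
    start, end = int(seg_indptr[batch_id]), int(seg_indptr[batch_id + 1])
    status = "skipped (kernel returns early)" if rank == 0 else f"rank={rank}"
    print(f"sequence {batch_id}: rows [{start}, {end}) -> {status}")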
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py

@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling and adding
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-    # x: (s, K), s is the sum of sequence lengths
-    # weights: (num_lora, N, K)
-    # output: (s, N)
+    """
+    Computes a segmented batched matrix multiplication for the LoRA B matrix
+    and adds the result to the output in-place.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
+            of shape `(s, K)`, where `s` is the total number of tokens.
+        weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
+            the base model's output for a fused add operation.
+    """
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)

     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)

@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
             mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(

@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = s_offset[:, None] < seg_len
-    if fuse_scaling_add:
-        partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)

@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
     )

     if base_output is None:
-        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
-        fuse_scaling_add = False
+        output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True

     _sgemm_lora_b_kernel[grid](
         x,

@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
     return output
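To make the intended semantics concrete, here is a rough eager-mode reference for what _sgemm_lora_b_kernel and its wrapper compute per segment, under the shapes from the new docstring (x: (s, K), weights: (num_lora, N, K)). This is an illustrative sketch built on the assumptions above (including that scaling multiplies the LoRA product), not the actual implementation:

import torch

def sgemm_lora_b_ref(x, weights, seg_indptr, weight_indices, lora_ranks,
                     scalings, base_output=None):
    # Hypothetical reference: out[rows] += scaling * x[rows, :rank] @ W[:, :rank].T
    s = x.shape[0]
    N = weights.shape[1]
    out = (torch.zeros((s, N), device=x.device, dtype=x.dtype)
           if base_output is None else base_output)
    for i in range(len(weight_indices)):
        w_idx = int(weight_indices[i])
        rank = int(lora_ranks[w_idx])
        if rank == 0:
            continue  # mirrors the kernel's early return; these rows stay as-is
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        out[start:end] += scalings[w_idx] * (
            x[start:end, :rank] @ weights[w_idx, :, :rank].T
        )
    return out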