Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ef9a378a
Unverified
Commit
ef9a378a
authored
Mar 29, 2025
by
chaobo jia
Committed by
GitHub
Mar 28, 2025
Browse files
[Feature] add multi-rank support for Lora (#4492)
Co-authored-by:
rudy152
<
czh1137892874@gmail.com
>
parent
6dea5c96
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
292 additions
and
97 deletions
+292
-97
python/sglang/bench_serving.py
python/sglang/bench_serving.py
+23
-3
python/sglang/srt/lora/backend/base_backend.py
python/sglang/srt/lora/backend/base_backend.py
+4
-4
python/sglang/srt/lora/backend/flashinfer_backend.py
python/sglang/srt/lora/backend/flashinfer_backend.py
+12
-9
python/sglang/srt/lora/backend/triton_backend.py
python/sglang/srt/lora/backend/triton_backend.py
+5
-8
python/sglang/srt/lora/layers.py
python/sglang/srt/lora/layers.py
+19
-33
python/sglang/srt/lora/lora_manager.py
python/sglang/srt/lora/lora_manager.py
+24
-7
python/sglang/srt/lora/mem_pool.py
python/sglang/srt/lora/mem_pool.py
+11
-5
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
+10
-4
python/sglang/srt/lora/triton_ops/qkv_lora_b.py
python/sglang/srt/lora/triton_ops/qkv_lora_b.py
+8
-3
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
+16
-5
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
+11
-6
python/sglang/srt/lora/utils.py
python/sglang/srt/lora/utils.py
+6
-0
test/srt/models/lora/test_lora.py
test/srt/models/lora/test_lora.py
+7
-5
test/srt/models/lora/test_multi_lora_backend.py
test/srt/models/lora/test_multi_lora_backend.py
+133
-2
test/srt/models/lora/utils.py
test/srt/models/lora/utils.py
+2
-2
test/srt/run_suite.py
test/srt/run_suite.py
+1
-1
No files found.
python/sglang/bench_serving.py
View file @
ef9a378a
...
@@ -965,7 +965,7 @@ async def benchmark(
...
@@ -965,7 +965,7 @@ async def benchmark(
request_rate
:
float
,
request_rate
:
float
,
max_concurrency
:
Optional
[
int
],
max_concurrency
:
Optional
[
int
],
disable_tqdm
:
bool
,
disable_tqdm
:
bool
,
lora_name
:
str
,
lora_name
s
:
List
[
str
]
,
extra_request_body
:
Dict
[
str
,
Any
],
extra_request_body
:
Dict
[
str
,
Any
],
profile
:
bool
,
profile
:
bool
,
pd_seperated
:
bool
=
False
,
pd_seperated
:
bool
=
False
,
...
@@ -988,6 +988,11 @@ async def benchmark(
...
@@ -988,6 +988,11 @@ async def benchmark(
# Warmup
# Warmup
print
(
"Starting initial single prompt test run..."
)
print
(
"Starting initial single prompt test run..."
)
test_prompt
,
test_prompt_len
,
test_output_len
=
input_requests
[
0
]
test_prompt
,
test_prompt_len
,
test_output_len
=
input_requests
[
0
]
if
lora_names
!=
None
and
len
(
lora_names
)
!=
0
:
lora_name
=
lora_names
[
0
]
else
:
lora_name
=
None
test_input
=
RequestFuncInput
(
test_input
=
RequestFuncInput
(
model
=
model_id
,
model
=
model_id
,
prompt
=
test_prompt
,
prompt
=
test_prompt
,
...
@@ -1028,6 +1033,12 @@ async def benchmark(
...
@@ -1028,6 +1033,12 @@ async def benchmark(
tasks
:
List
[
asyncio
.
Task
]
=
[]
tasks
:
List
[
asyncio
.
Task
]
=
[]
async
for
request
in
get_request
(
input_requests
,
request_rate
):
async
for
request
in
get_request
(
input_requests
,
request_rate
):
prompt
,
prompt_len
,
output_len
=
request
prompt
,
prompt_len
,
output_len
=
request
if
lora_names
!=
None
and
len
(
lora_names
)
!=
0
:
idx
=
random
.
randint
(
0
,
len
(
lora_names
)
-
1
)
lora_name
=
lora_names
[
idx
]
else
:
lora_name
=
None
request_func_input
=
RequestFuncInput
(
request_func_input
=
RequestFuncInput
(
model
=
model_id
,
model
=
model_id
,
prompt
=
prompt
,
prompt
=
prompt
,
...
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
...
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate
=
args
.
request_rate
,
request_rate
=
args
.
request_rate
,
max_concurrency
=
args
.
max_concurrency
,
max_concurrency
=
args
.
max_concurrency
,
disable_tqdm
=
args
.
disable_tqdm
,
disable_tqdm
=
args
.
disable_tqdm
,
lora_name
=
args
.
lora_name
,
lora_name
s
=
args
.
lora_name
,
extra_request_body
=
extra_request_body
,
extra_request_body
=
extra_request_body
,
profile
=
args
.
profile
,
profile
=
args
.
profile
,
pd_seperated
=
args
.
pd_seperated
,
pd_seperated
=
args
.
pd_seperated
,
...
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
...
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
print
(
f
"Fail to set RLIMIT_NOFILE:
{
e
}
"
)
print
(
f
"Fail to set RLIMIT_NOFILE:
{
e
}
"
)
class
LoRAPathAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
setattr
(
namespace
,
self
.
dest
,
[])
for
lora_name
in
values
:
getattr
(
namespace
,
self
.
dest
).
append
(
lora_name
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
description
=
"Benchmark the online serving throughput."
)
parser
=
ArgumentParser
(
description
=
"Benchmark the online serving throughput."
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
...
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--lora-name"
,
"--lora-name"
,
type
=
str
,
type
=
str
,
nargs
=
"*"
,
default
=
None
,
default
=
None
,
help
=
"The name of LoRA adapter"
,
action
=
LoRAPathAction
,
help
=
"The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}..."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--prompt-suffix"
,
"--prompt-suffix"
,
...
...
python/sglang/srt/lora/backend/base_backend.py
View file @
ef9a378a
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.srt.lora.utils
import
LoRABatchInfo
def
get_fuse_output_
scaling_
add_from_name
(
name
:
str
)
->
bool
:
def
get_fuse_output_add_from_name
(
name
:
str
)
->
bool
:
mapping
=
{
mapping
=
{
"triton"
:
True
,
"triton"
:
True
,
"flashinfer"
:
False
,
"flashinfer"
:
False
,
...
@@ -28,14 +28,14 @@ class BaseLoRABackend:
...
@@ -28,14 +28,14 @@ class BaseLoRABackend:
Args:
Args:
name: name of backend
name: name of backend
batch_info: information of current batch for use
batch_info: information of current batch for use
fuse_output_
scaling_
add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of
scaling and
adding will be fused into kernel
and the operation of adding will be fused into kernel
"""
"""
def
__init__
(
self
,
name
:
str
,
batch_info
:
LoRABatchInfo
=
None
):
def
__init__
(
self
,
name
:
str
,
batch_info
:
LoRABatchInfo
=
None
):
self
.
name
=
name
self
.
name
=
name
self
.
batch_info
=
batch_info
self
.
batch_info
=
batch_info
self
.
fuse_output_
scaling_
add
=
get_fuse_output_
scaling_
add_from_name
(
name
)
self
.
fuse_output_add
=
get_fuse_output_add_from_name
(
name
)
self
.
fuse_stacked_lora_b
=
get_fuse_stacked_lora_b_from_name
(
name
)
self
.
fuse_stacked_lora_b
=
get_fuse_stacked_lora_b_from_name
(
name
)
def
run_lora_a_sgemm
(
def
run_lora_a_sgemm
(
...
...
python/sglang/srt/lora/backend/flashinfer_backend.py
View file @
ef9a378a
...
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
...
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
self
,
x
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
*
args
,
**
kwargs
self
,
x
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
segment_gemm
.
run
(
return
(
x
=
x
,
self
.
segment_gemm
.
run
(
weights
=
weights
,
x
=
x
,
batch_size
=
self
.
batch_info
.
bs
,
weights
=
weights
,
weight_column_major
=
True
,
batch_size
=
self
.
batch_info
.
bs
,
seg_indptr
=
self
.
batch_info
.
seg_indptr
,
weight_column_major
=
True
,
weight_indices
=
self
.
batch_info
.
weight_indices
,
seg_indptr
=
self
.
batch_info
.
seg_indptr
,
weight_indices
=
self
.
batch_info
.
weight_indices
,
)
*
self
.
batch_info
.
scalings
[
0
]
)
)
def
run_qkv_lora
(
def
run_qkv_lora
(
...
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
...
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
weights
=
kv_lora_b
[
1
],
weights
=
kv_lora_b
[
1
],
)
)
return
lora_output
return
lora_output
*
self
.
batch_info
.
scalings
[
0
]
def
run_gate_up_lora
(
def
run_gate_up_lora
(
self
,
self
,
...
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
...
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
weights
=
gate_up_lora_b
[
1
],
weights
=
gate_up_lora_b
[
1
],
)
)
return
lora_output
return
lora_output
*
self
.
batch_info
.
scalings
[
0
]
python/sglang/srt/lora/backend/triton_backend.py
View file @
ef9a378a
...
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
...
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
base_output
:
torch
.
Tensor
=
None
,
base_output
:
torch
.
Tensor
=
None
,
scaling
:
float
=
1.0
,
*
args
,
*
args
,
**
kwargs
**
kwargs
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
sgemm_lora_b_fwd
(
x
,
weights
,
self
.
batch_info
,
base_output
,
scaling
)
return
sgemm_lora_b_fwd
(
x
,
weights
,
self
.
batch_info
,
base_output
)
def
run_qkv_lora
(
def
run_qkv_lora
(
self
,
self
,
...
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
...
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
output_offset
:
torch
.
Tensor
,
output_offset
:
torch
.
Tensor
,
max_qkv_out_dim
:
int
,
max_qkv_out_dim
:
int
,
base_output
:
torch
.
Tensor
=
None
,
base_output
:
torch
.
Tensor
=
None
,
scaling
:
float
=
1.0
,
*
args
,
*
args
,
**
kwargs
**
kwargs
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
...
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert
isinstance
(
qkv_lora_b
,
torch
.
Tensor
)
assert
isinstance
(
qkv_lora_b
,
torch
.
Tensor
)
lora_a_output
=
sgemm_lora_a_fwd
(
x
,
qkv_lora_a
,
self
.
batch_info
)
lora_a_output
=
sgemm_lora_a_fwd
(
x
,
qkv_lora_a
,
self
.
batch_info
,
stack_num
=
3
)
lora_output
=
qkv_lora_b_fwd
(
lora_output
=
qkv_lora_b_fwd
(
lora_a_output
,
lora_a_output
,
qkv_lora_b
,
qkv_lora_b
,
...
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
...
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
output_offset
,
output_offset
,
max_qkv_out_dim
,
max_qkv_out_dim
,
base_output
,
base_output
,
scaling
,
)
)
return
lora_output
return
lora_output
...
@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
...
@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
gate_up_lora_a
:
torch
.
Tensor
,
gate_up_lora_a
:
torch
.
Tensor
,
gate_up_lora_b
:
torch
.
Tensor
,
gate_up_lora_b
:
torch
.
Tensor
,
base_output
:
torch
.
Tensor
=
None
,
base_output
:
torch
.
Tensor
=
None
,
scaling
:
float
=
1.0
,
*
args
,
*
args
,
**
kwargs
**
kwargs
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
...
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
output_dim
=
gate_up_lora_b
.
shape
[
-
2
]
//
2
output_dim
=
gate_up_lora_b
.
shape
[
-
2
]
//
2
# lora_a_output: (s, 2 * r)
# lora_a_output: (s, 2 * r)
lora_a_output
=
sgemm_lora_a_fwd
(
x
,
gate_up_lora_a
,
self
.
batch_info
)
lora_a_output
=
sgemm_lora_a_fwd
(
x
,
gate_up_lora_a
,
self
.
batch_info
,
stack_num
=
2
)
lora_output
=
gate_up_lora_b_fwd
(
lora_output
=
gate_up_lora_b_fwd
(
lora_a_output
,
lora_a_output
,
gate_up_lora_b
,
gate_up_lora_b
,
self
.
batch_info
,
self
.
batch_info
,
output_dim
,
output_dim
,
base_output
,
base_output
,
scaling
,
)
)
return
lora_output
return
lora_output
python/sglang/srt/lora/layers.py
View file @
ef9a378a
...
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
...
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
base_layer
:
nn
.
Module
,
base_layer
:
nn
.
Module
,
lora_rank
:
int
,
scaling
:
float
,
lora_backend
:
BaseLoRABackend
,
lora_backend
:
BaseLoRABackend
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
base_layer
:
nn
.
Module
=
base_layer
self
.
base_layer
:
nn
.
Module
=
base_layer
self
.
lora_rank
:
int
=
lora_rank
self
.
scaling
:
float
=
scaling
self
.
set_lora
:
bool
=
False
self
.
set_lora
:
bool
=
False
self
.
lora_backend
:
BaseLoRABackend
=
lora_backend
self
.
lora_backend
:
BaseLoRABackend
=
lora_backend
...
@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def
__init__
(
def
__init__
(
self
,
self
,
base_layer
:
VocabParallelEmbedding
,
base_layer
:
VocabParallelEmbedding
,
lora_rank
:
int
,
scaling
:
float
,
lora_backend
:
BaseLoRABackend
,
lora_backend
:
BaseLoRABackend
,
)
->
None
:
)
->
None
:
super
().
__init__
(
base_layer
,
lora_rank
,
scaling
,
lora_backend
)
super
().
__init__
(
base_layer
,
lora_backend
)
self
.
weight
=
base_layer
.
weight
self
.
weight
=
base_layer
.
weight
...
@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def
__init__
(
def
__init__
(
self
,
self
,
base_layer
:
ColumnParallelLinear
,
base_layer
:
ColumnParallelLinear
,
lora_rank
:
int
,
scaling
:
float
,
lora_backend
:
BaseLoRABackend
,
lora_backend
:
BaseLoRABackend
,
)
->
None
:
)
->
None
:
super
().
__init__
(
base_layer
,
lora_rank
,
scaling
,
lora_backend
)
super
().
__init__
(
base_layer
,
lora_backend
)
def
set_lora_info
(
def
set_lora_info
(
self
,
self
,
...
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
B_buffer
=
B_buffer
self
.
B_buffer
=
B_buffer
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
backend_kwargs
=
{
"base_output"
:
base_output
,
"scaling"
:
self
.
scaling
}
backend_kwargs
=
{
"base_output"
:
base_output
}
lora_a_output
=
self
.
lora_backend
.
run_lora_a_sgemm
(
x
,
self
.
A_buffer
)
lora_a_output
=
self
.
lora_backend
.
run_lora_a_sgemm
(
x
,
self
.
A_buffer
)
lora_output
=
self
.
lora_backend
.
run_lora_b_sgemm
(
lora_output
=
self
.
lora_backend
.
run_lora_b_sgemm
(
lora_a_output
,
lora_a_output
,
...
@@ -96,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -96,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
)
)
return
(
return
(
lora_output
lora_output
if
self
.
lora_backend
.
fuse_output_
scaling_
add
if
self
.
lora_backend
.
fuse_output_add
else
base_output
+
lora_output
*
self
.
scaling
else
base_output
+
lora_output
)
)
def
forward
(
self
,
input_
:
torch
.
Tensor
):
def
forward
(
self
,
input_
:
torch
.
Tensor
):
...
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
__init__
(
def
__init__
(
self
,
self
,
base_layer
:
MergedColumnParallelLinear
,
base_layer
:
MergedColumnParallelLinear
,
lora_rank
:
int
,
scaling
:
float
,
lora_backend
:
BaseLoRABackend
,
lora_backend
:
BaseLoRABackend
,
)
->
None
:
)
->
None
:
super
().
__init__
(
base_layer
,
lora_rank
,
scaling
,
lora_backend
)
super
().
__init__
(
base_layer
,
lora_backend
)
def
set_lora_info
(
def
set_lora_info
(
self
,
self
,
...
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
B_buffer_gate_up
=
(
B_buffer
[
0
],
B_buffer
[
1
])
self
.
B_buffer_gate_up
=
(
B_buffer
[
0
],
B_buffer
[
1
])
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
backend_kwargs
=
{
"base_output"
:
base_output
,
"scaling"
:
self
.
scaling
}
backend_kwargs
=
{
"base_output"
:
base_output
}
lora_output
=
self
.
lora_backend
.
run_gate_up_lora
(
lora_output
=
self
.
lora_backend
.
run_gate_up_lora
(
x
,
x
,
...
@@ -165,8 +155,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -165,8 +155,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
)
return
(
return
(
lora_output
lora_output
if
self
.
lora_backend
.
fuse_output_
scaling_
add
if
self
.
lora_backend
.
fuse_output_add
else
base_output
+
lora_output
*
self
.
scaling
else
base_output
+
lora_output
)
)
def
slice_lora_a_weights
(
self
,
A
:
torch
.
Tensor
,
tp_rank
:
int
):
def
slice_lora_a_weights
(
self
,
A
:
torch
.
Tensor
,
tp_rank
:
int
):
...
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
init__
(
def
init__
(
self
,
self
,
base_layer
:
QKVParallelLinear
,
base_layer
:
QKVParallelLinear
,
lora_rank
:
int
,
scaling
:
float
,
lora_backend
:
BaseLoRABackend
,
lora_backend
:
BaseLoRABackend
,
)
->
None
:
)
->
None
:
super
().
__init__
(
base_layer
,
lora_rank
,
scaling
,
lora_backend
)
super
().
__init__
(
base_layer
,
lora_backend
)
def
set_lora_info
(
def
set_lora_info
(
self
,
self
,
...
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
)
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
backend_kwargs
=
{
"base_output"
:
base_output
,
"scaling"
:
self
.
scaling
}
backend_kwargs
=
{
"base_output"
:
base_output
}
if
self
.
lora_backend
.
fuse_stacked_lora_b
:
if
self
.
lora_backend
.
fuse_stacked_lora_b
:
backend_kwargs
[
"output_offset"
]
=
self
.
output_offset
backend_kwargs
[
"output_offset"
]
=
self
.
output_offset
backend_kwargs
[
"max_qkv_out_dim"
]
=
self
.
max_qkv_out_dim
backend_kwargs
[
"max_qkv_out_dim"
]
=
self
.
max_qkv_out_dim
...
@@ -243,8 +231,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -243,8 +231,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
)
return
(
return
(
lora_output
lora_output
if
self
.
lora_backend
.
fuse_output_
scaling_
add
if
self
.
lora_backend
.
fuse_output_add
else
base_output
+
lora_output
*
self
.
scaling
else
base_output
+
lora_output
)
)
def
slice_lora_a_weights
(
self
,
A
:
torch
.
Tensor
,
tp_rank
:
int
):
def
slice_lora_a_weights
(
self
,
A
:
torch
.
Tensor
,
tp_rank
:
int
):
...
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
__init__
(
def
__init__
(
self
,
self
,
base_layer
:
RowParallelLinear
,
base_layer
:
RowParallelLinear
,
lora_rank
:
int
,
scaling
:
float
,
lora_backend
:
BaseLoRABackend
,
lora_backend
:
BaseLoRABackend
,
)
->
None
:
)
->
None
:
super
().
__init__
(
base_layer
,
lora_rank
,
scaling
,
lora_backend
)
super
().
__init__
(
base_layer
,
lora_backend
)
def
set_lora_info
(
self
,
A_buffer
:
torch
.
Tensor
,
B_buffer
:
torch
.
Tensor
):
def
set_lora_info
(
self
,
A_buffer
:
torch
.
Tensor
,
B_buffer
:
torch
.
Tensor
):
self
.
set_lora
=
True
self
.
set_lora
=
True
...
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
B_buffer
=
B_buffer
self
.
B_buffer
=
B_buffer
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply_lora
(
self
,
base_output
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
backend_kwargs
=
{
"base_output"
:
base_output
,
"scaling"
:
self
.
scaling
}
backend_kwargs
=
{
"base_output"
:
base_output
}
lora_a_output
=
self
.
lora_backend
.
run_lora_a_sgemm
(
x
,
self
.
A_buffer
)
lora_a_output
=
self
.
lora_backend
.
run_lora_a_sgemm
(
x
,
self
.
A_buffer
)
lora_output
=
self
.
lora_backend
.
run_lora_b_sgemm
(
lora_output
=
self
.
lora_backend
.
run_lora_b_sgemm
(
lora_a_output
,
lora_a_output
,
...
@@ -294,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -294,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
)
)
return
(
return
(
lora_output
lora_output
if
self
.
lora_backend
.
fuse_output_
scaling_
add
if
self
.
lora_backend
.
fuse_output_add
else
base_output
+
lora_output
*
self
.
scaling
else
base_output
+
lora_output
)
)
def
forward
(
self
,
input_
:
torch
.
Tensor
):
def
forward
(
self
,
input_
:
torch
.
Tensor
):
...
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
get_lora_layer
(
def
get_lora_layer
(
layer
:
nn
.
Module
,
lora_rank
:
int
,
scaling
:
int
,
lora_backend
:
BaseLoRABackend
layer
:
nn
.
Module
,
lora_backend
:
BaseLoRABackend
)
->
BaseLayerWithLoRA
:
)
->
BaseLayerWithLoRA
:
supported_layer_types
=
{
supported_layer_types
=
{
# the order matters
# the order matters
...
@@ -356,6 +342,6 @@ def get_lora_layer(
...
@@ -356,6 +342,6 @@ def get_lora_layer(
}
}
for
src_layer_type
,
lora_layer_type
in
supported_layer_types
.
items
():
for
src_layer_type
,
lora_layer_type
in
supported_layer_types
.
items
():
if
isinstance
(
layer
,
src_layer_type
):
# pylint: disable=unidiomatic-typecheck
if
isinstance
(
layer
,
src_layer_type
):
# pylint: disable=unidiomatic-typecheck
ret
=
lora_layer_type
(
layer
,
lora_rank
,
scaling
,
lora_backend
)
ret
=
lora_layer_type
(
layer
,
lora_backend
)
return
ret
return
ret
raise
Exception
(
f
"No corresponding LoRA layer supported for
{
type
(
layer
)
}
."
)
raise
Exception
(
f
"No corresponding LoRA layer supported for
{
type
(
layer
)
}
."
)
python/sglang/srt/lora/lora_manager.py
View file @
ef9a378a
...
@@ -103,11 +103,14 @@ class LoRAManager:
...
@@ -103,11 +103,14 @@ class LoRAManager:
self
.
loras
[
name
]
=
lora_adapter
self
.
loras
[
name
]
=
lora_adapter
# misc lora configs
# misc lora configs
# FIXME remove the restrictions after implementing unified paging
self
.
max_lora_dim
:
int
=
max
([
x
.
hf_config
[
"r"
]
for
x
in
self
.
configs
.
values
()])
self
.
max_lora_dim
:
int
=
max
([
x
.
hf_config
[
"r"
]
for
x
in
self
.
configs
.
values
()])
self
.
scaling
:
float
=
list
(
self
.
loras
.
values
())[
0
].
scaling
assert
all
(
x
.
hf_config
[
"r"
]
==
self
.
max_lora_dim
for
x
in
self
.
configs
.
values
())
if
self
.
lora_backend
==
"flashinfer"
:
assert
all
(
x
.
scaling
==
self
.
scaling
for
x
in
self
.
loras
.
values
())
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
max_lora_dim
=
max
([
x
.
hf_config
[
"r"
]
for
x
in
self
.
configs
.
values
()])
scaling
=
list
(
self
.
loras
.
values
())[
0
].
scaling
assert
all
(
x
.
hf_config
[
"r"
]
==
max_lora_dim
for
x
in
self
.
configs
.
values
())
assert
all
(
x
.
scaling
==
scaling
for
x
in
self
.
loras
.
values
())
# Convert original model layers to layers with LoRA
# Convert original model layers to layers with LoRA
self
.
convert_to_lora_layers
()
self
.
convert_to_lora_layers
()
...
@@ -133,6 +136,10 @@ class LoRAManager:
...
@@ -133,6 +136,10 @@ class LoRAManager:
assert
len
(
cur_uids
)
<=
self
.
max_loras_per_batch
assert
len
(
cur_uids
)
<=
self
.
max_loras_per_batch
self
.
memory_pool
.
prepare_lora_batch
(
cur_uids
,
self
.
loras
)
self
.
memory_pool
.
prepare_lora_batch
(
cur_uids
,
self
.
loras
)
# FIXME: Handle lora uid with None more safely
if
cur_uids
==
set
([
None
]):
return
# set up batch info shared by all lora moruldes
# set up batch info shared by all lora moruldes
bs
=
forward_batch
.
batch_size
bs
=
forward_batch
.
batch_size
seg_lens
=
(
seg_lens
=
(
...
@@ -144,8 +151,18 @@ class LoRAManager:
...
@@ -144,8 +151,18 @@ class LoRAManager:
seg_indptr
[
1
:]
=
torch
.
cumsum
(
seg_lens
,
dim
=
0
)
seg_indptr
[
1
:]
=
torch
.
cumsum
(
seg_lens
,
dim
=
0
)
max_len
=
int
(
torch
.
max
(
seg_lens
))
max_len
=
int
(
torch
.
max
(
seg_lens
))
weight_indices
=
torch
.
empty
((
bs
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
weight_indices
=
torch
.
empty
((
bs
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
lora_ranks
=
torch
.
empty
(
(
self
.
max_loras_per_batch
,),
dtype
=
torch
.
int64
,
device
=
"cuda"
)
scalings
=
torch
.
empty
(
(
self
.
max_loras_per_batch
,),
dtype
=
torch
.
float
,
device
=
"cuda"
)
for
i
,
lora_path
in
enumerate
(
forward_batch
.
lora_paths
):
for
i
,
lora_path
in
enumerate
(
forward_batch
.
lora_paths
):
weight_indices
[
i
]
=
self
.
memory_pool
.
get_buffer_id
(
lora_path
)
weight_indices
[
i
]
=
self
.
memory_pool
.
get_buffer_id
(
lora_path
)
lora
=
self
.
loras
[
lora_path
]
lora_ranks
[
weight_indices
[
i
]]
=
lora
.
config
.
hf_config
[
"r"
]
scalings
[
weight_indices
[
i
]]
=
lora
.
scaling
batch_info
=
LoRABatchInfo
(
batch_info
=
LoRABatchInfo
(
bs
=
bs
,
bs
=
bs
,
...
@@ -153,6 +170,8 @@ class LoRAManager:
...
@@ -153,6 +170,8 @@ class LoRAManager:
seg_indptr
=
seg_indptr
,
seg_indptr
=
seg_indptr
,
max_len
=
max_len
,
max_len
=
max_len
,
weight_indices
=
weight_indices
,
weight_indices
=
weight_indices
,
lora_ranks
=
lora_ranks
,
scalings
=
scalings
,
)
)
self
.
lora_backend
.
set_batch_info
(
batch_info
)
self
.
lora_backend
.
set_batch_info
(
batch_info
)
...
@@ -185,9 +204,7 @@ class LoRAManager:
...
@@ -185,9 +204,7 @@ class LoRAManager:
)
)
def
set_lora_module
(
self
,
module_name
,
module
):
def
set_lora_module
(
self
,
module_name
,
module
):
lora_module
=
get_lora_layer
(
lora_module
=
get_lora_layer
(
module
,
self
.
lora_backend
)
module
,
self
.
max_lora_dim
,
self
.
scaling
,
self
.
lora_backend
)
replace_submodule
(
self
.
base_model
,
module_name
,
lora_module
)
replace_submodule
(
self
.
base_model
,
module_name
,
lora_module
)
return
lora_module
return
lora_module
...
...
python/sglang/srt/lora/mem_pool.py
View file @
ef9a378a
...
@@ -167,6 +167,7 @@ class LoRAMemoryPool:
...
@@ -167,6 +167,7 @@ class LoRAMemoryPool:
return
return
assert
lora_adapter
is
not
None
assert
lora_adapter
is
not
None
lora_rank
=
lora_adapter
.
config
.
hf_config
[
"r"
]
for
layer_id
in
range
(
self
.
num_layer
):
for
layer_id
in
range
(
self
.
num_layer
):
layer_weights
=
lora_adapter
.
layers
[
layer_id
].
weights
layer_weights
=
lora_adapter
.
layers
[
layer_id
].
weights
temp_A_buffer
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
temp_A_buffer
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
...
@@ -208,17 +209,22 @@ class LoRAMemoryPool:
...
@@ -208,17 +209,22 @@ class LoRAMemoryPool:
)
)
for
name
,
weights
in
temp_A_buffer
.
items
():
for
name
,
weights
in
temp_A_buffer
.
items
():
self
.
A_buffer
[
name
][
layer_id
][
buffer_id
].
copy_
(
weights
)
c
=
get_stacked_multiply
(
name
)
self
.
A_buffer
[
name
][
layer_id
][
buffer_id
][:
lora_rank
*
c
,
:].
copy_
(
weights
)
for
name
,
weights
in
temp_B_buffer
.
items
():
for
name
,
weights
in
temp_B_buffer
.
items
():
c
=
get_stacked_multiply
(
name
)
c
=
get_stacked_multiply
(
name
)
if
c
>
1
:
if
c
>
1
:
for
stacked_id
in
range
(
c
):
for
stacked_id
in
range
(
c
):
self
.
B_buffer
[
name
][
layer_id
][
stacked_id
][
buffer_id
]
.
copy_
(
self
.
B_buffer
[
name
][
layer_id
][
stacked_id
][
buffer_id
]
[
weights
[
stacked_id
]
:,
:
lora_rank
)
].
copy_
(
weights
[
stacked_id
]
)
else
:
else
:
self
.
B_buffer
[
name
][
layer_id
][
0
][
buffer_id
].
copy_
(
weights
)
self
.
B_buffer
[
name
][
layer_id
][
0
][
buffer_id
][:,
:
lora_rank
].
copy_
(
weights
)
def
get_tensor
(
def
get_tensor
(
self
,
weight_name
:
str
,
layer_id
:
int
,
lora_type
:
LoRAType
self
,
weight_name
:
str
,
layer_id
:
int
,
lora_type
:
LoRAType
...
...
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
View file @
ef9a378a
...
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
...
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
w_stride_2
,
w_stride_2
,
output_stride_0
,
output_stride_0
,
output_stride_1
,
output_stride_1
,
# Information on sequence lengths and weight id
# Information on sequence lengths
,ranks
and weight id
seg_lens
,
seg_lens
,
seg_indptr
,
seg_indptr
,
weight_indices
,
weight_indices
,
lora_ranks
,
# Meta parameters
# Meta parameters
BLOCK_S
:
tl
.
constexpr
,
BLOCK_S
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
# For fused output scaling and adding
# For fused output scaling and adding
fuse_scaling_add
,
fuse_scaling_add
,
scaling
,
scaling
s
,
):
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.
# This kernel packs 2 sgemms (gate/up) into a single kernel.
...
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
...
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
n_start
=
gate_up_id
*
output_dim
# offset on output dim
n_start
=
gate_up_id
*
output_dim
# offset on output dim
rank
=
tl
.
load
(
lora_ranks
+
w_index
)
scaling
=
tl
.
load
(
scalings
+
w_index
)
# Adjust K (rank) according to the specific LoRA adapter
K
=
tl
.
minimum
(
K
,
rank
)
# The tile in output matrix will have (pid_s, pid_n) as id
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n
=
tl
.
cdiv
(
output_dim
,
BLOCK_N
)
num_pid_n
=
tl
.
cdiv
(
output_dim
,
BLOCK_N
)
...
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
...
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
batch_info
:
LoRABatchInfo
,
batch_info
:
LoRABatchInfo
,
output_dim
:
int
,
output_dim
:
int
,
base_output
:
torch
.
Tensor
=
None
,
base_output
:
torch
.
Tensor
=
None
,
scaling
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# x: (s, 2 * r)
# x: (s, 2 * r)
...
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
...
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
batch_info
.
seg_lens
,
batch_info
.
seg_lens
,
batch_info
.
seg_indptr
,
batch_info
.
seg_indptr
,
batch_info
.
weight_indices
,
batch_info
.
weight_indices
,
batch_info
.
lora_ranks
,
BLOCK_S
,
BLOCK_S
,
BLOCK_OUT
,
BLOCK_OUT
,
BLOCK_R
,
BLOCK_R
,
fuse_scaling_add
,
fuse_scaling_add
,
scaling
,
batch_info
.
scaling
s
,
)
)
return
output
return
output
python/sglang/srt/lora/triton_ops/qkv_lora_b.py
View file @
ef9a378a
...
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
...
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
seg_lens
,
seg_lens
,
seg_indptr
,
seg_indptr
,
weight_indices
,
weight_indices
,
lora_ranks
,
# Offsets of q/k/v slice on output dimension
# Offsets of q/k/v slice on output dimension
n_offs
,
n_offs
,
# Meta parameters
# Meta parameters
...
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
...
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
BLOCK_K
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
# For fused output scaling and adding
# For fused output scaling and adding
fuse_scaling_add
,
fuse_scaling_add
,
scaling
,
scaling
s
,
):
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
...
@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
...
@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
n_start
=
tl
.
load
(
n_offs
+
qkv_id
)
n_start
=
tl
.
load
(
n_offs
+
qkv_id
)
n_size
=
tl
.
load
(
n_offs
+
qkv_id
+
1
)
-
n_start
n_size
=
tl
.
load
(
n_offs
+
qkv_id
+
1
)
-
n_start
rank
=
tl
.
load
(
lora_ranks
+
w_index
)
scaling
=
tl
.
load
(
scalings
+
w_index
)
# Adjust K (rank) according to the specific LoRA adapter
K
=
tl
.
minimum
(
K
,
rank
)
# The tile in output matrix will have (pid_s, pid_n) as id
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n
=
tl
.
cdiv
(
max_qkv_out_dim
,
BLOCK_N
)
num_pid_n
=
tl
.
cdiv
(
max_qkv_out_dim
,
BLOCK_N
)
...
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
...
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
output_offset
:
torch
.
Tensor
,
output_offset
:
torch
.
Tensor
,
max_qkv_out_dim
:
int
,
max_qkv_out_dim
:
int
,
base_output
:
torch
.
Tensor
=
None
,
base_output
:
torch
.
Tensor
=
None
,
scaling
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# x: (s, 3 * r)
# x: (s, 3 * r)
...
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
...
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
batch_info
.
seg_lens
,
batch_info
.
seg_lens
,
batch_info
.
seg_indptr
,
batch_info
.
seg_indptr
,
batch_info
.
weight_indices
,
batch_info
.
weight_indices
,
batch_info
.
lora_ranks
,
output_offset
,
output_offset
,
BLOCK_S
,
BLOCK_S
,
BLOCK_OUT
,
BLOCK_OUT
,
BLOCK_R
,
BLOCK_R
,
fuse_scaling_add
,
fuse_scaling_add
,
scaling
,
batch_info
.
scaling
s
,
)
)
return
output
return
output
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
View file @
ef9a378a
...
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
...
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
weights
,
weights
,
output
,
output
,
# Matrix dimensions
# Matrix dimensions
N
,
# r
N
,
#
stack_num *
r
K
,
# input_dim
K
,
# input_dim
stack_num
,
# Strides
# Strides
x_stride_0
,
x_stride_0
,
x_stride_1
,
x_stride_1
,
...
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
...
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
w_stride_2
,
w_stride_2
,
output_stride_0
,
output_stride_0
,
output_stride_1
,
output_stride_1
,
# Information on sequence lengths and weight id
# Information on sequence lengths
,ranks
and weight id
seg_lens
,
seg_lens
,
seg_indptr
,
seg_indptr
,
weight_indices
,
weight_indices
,
lora_ranks
,
# Meta parameters
# Meta parameters
BLOCK_S
:
tl
.
constexpr
,
BLOCK_S
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
...
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
...
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
seg_len
=
tl
.
load
(
seg_lens
+
batch_id
)
seg_len
=
tl
.
load
(
seg_lens
+
batch_id
)
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
rank
=
tl
.
load
(
lora_ranks
+
w_index
)
# Adjust N (stack_num * max_rank) according to the specific LoRA adapter
N
=
tl
.
minimum
(
N
,
rank
*
stack_num
)
# The tile in output matrix will have (pid_s, pid_n) as id
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_N
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_N
)
...
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
...
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
def
sgemm_lora_a_fwd
(
def
sgemm_lora_a_fwd
(
x
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
batch_info
:
LoRABatchInfo
x
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
batch_info
:
LoRABatchInfo
,
stack_num
:
int
=
1
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# x: (s, input_dim)
# x: (s, input_dim)
# weights: (num_lora, r, input_dim)
# weights: (num_lora, stack_num * r, input_dim)
# output: (s, r)
# output: (s, stack_num * r)
# stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
# when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
# when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
# input_dim is much larger than r
# input_dim is much larger than r
...
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
...
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
output
,
output
,
R
,
R
,
K
,
K
,
stack_num
,
x
.
stride
(
0
),
x
.
stride
(
0
),
x
.
stride
(
1
),
x
.
stride
(
1
),
weights
.
stride
(
0
),
weights
.
stride
(
0
),
...
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
...
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
batch_info
.
seg_lens
,
batch_info
.
seg_lens
,
batch_info
.
seg_indptr
,
batch_info
.
seg_indptr
,
batch_info
.
weight_indices
,
batch_info
.
weight_indices
,
batch_info
.
lora_ranks
,
BLOCK_S
,
BLOCK_S
,
BLOCK_R
,
BLOCK_R
,
BLOCK_K
,
BLOCK_K
,
...
...
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
View file @
ef9a378a
...
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
...
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
seg_lens
,
seg_lens
,
seg_indptr
,
seg_indptr
,
weight_indices
,
weight_indices
,
lora_ranks
,
# Meta parameters
# Meta parameters
BLOCK_S
:
tl
.
constexpr
,
BLOCK_S
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
# For fused output scaling and adding
# For fused output scaling and adding
fuse_scaling_add
,
fuse_scaling_add
,
scaling
,
scaling
s
,
):
):
# x: (s, K), s is the sum of sequence lengths
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# weights: (num_lora, N, K)
...
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
...
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
seg_len
=
tl
.
load
(
seg_lens
+
batch_id
)
seg_len
=
tl
.
load
(
seg_lens
+
batch_id
)
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
rank
=
tl
.
load
(
lora_ranks
+
w_index
)
scaling
=
tl
.
load
(
scalings
+
w_index
)
# Adjust K (rank) according to the specific LoRA adapter
K
=
tl
.
minimum
(
K
,
rank
)
# The tile in output matrix will have (pid_s, pid_n) as id
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_N
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_N
)
...
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
...
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
weights
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
batch_info
:
LoRABatchInfo
,
batch_info
:
LoRABatchInfo
,
base_output
:
torch
.
Tensor
=
None
,
base_output
:
torch
.
Tensor
=
None
,
scaling
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# x: (s, r)
# x: (s,
max_
r)
# weights: (num_lora, output_dim, r)
# weights: (num_lora, output_dim,
max_
r)
# output: (s, output_dim)
# output: (s, output_dim)
# output_dim is much larger than r
# output_dim is much larger than
max_
r
assert
x
.
is_contiguous
()
assert
x
.
is_contiguous
()
assert
weights
.
is_contiguous
()
assert
weights
.
is_contiguous
()
...
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
...
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
batch_info
.
seg_lens
,
batch_info
.
seg_lens
,
batch_info
.
seg_indptr
,
batch_info
.
seg_indptr
,
batch_info
.
weight_indices
,
batch_info
.
weight_indices
,
batch_info
.
lora_ranks
,
BLOCK_S
,
BLOCK_S
,
BLOCK_N
,
BLOCK_N
,
BLOCK_R
,
BLOCK_R
,
fuse_scaling_add
,
fuse_scaling_add
,
scaling
,
batch_info
.
scaling
s
,
)
)
return
output
return
output
python/sglang/srt/lora/utils.py
View file @
ef9a378a
...
@@ -25,6 +25,12 @@ class LoRABatchInfo:
...
@@ -25,6 +25,12 @@ class LoRABatchInfo:
# The index of lora adapter used by each sequence, in shape (bs,)
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices
:
torch
.
Tensor
weight_indices
:
torch
.
Tensor
# ranks of each lora adapter, in shape (lora_num,)
lora_ranks
:
torch
.
Tensor
# scaling of each lora adapter, in shape (lora_num,)
scalings
:
torch
.
Tensor
class
LoRAType
(
Enum
):
class
LoRAType
(
Enum
):
LORA_A
=
0
LORA_A
=
0
...
...
test/srt/models/lora/test_lora.py
View file @
ef9a378a
...
@@ -29,7 +29,7 @@ LORA_SETS = [
...
@@ -29,7 +29,7 @@ LORA_SETS = [
# {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]},
# {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]},
# {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
# {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
# {
# {
#
"base": "mistralai/Mistral-7B-Instruct-v0.3",
# "base": "mistralai/Mistral-7B-Instruct-v0.3",
# "loras": [
# "loras": [
# "/home/ying/test_lora",
# "/home/ying/test_lora",
# "/home/ying/test_lora_1",
# "/home/ying/test_lora_1",
...
@@ -176,9 +176,11 @@ class TestLoRA(CustomTestCase):
...
@@ -176,9 +176,11 @@ class TestLoRA(CustomTestCase):
print
(
f
"
{
srt_no_lora_outputs
.
output_strs
=
}
"
)
print
(
f
"
{
srt_no_lora_outputs
.
output_strs
=
}
"
)
print
(
f
"
{
srt_outputs_lora_path_none
.
output_strs
=
}
"
)
print
(
f
"
{
srt_outputs_lora_path_none
.
output_strs
=
}
"
)
for
i
in
range
(
len
(
prompts
)):
for
i
in
range
(
len
(
prompts
)):
assert
srt_outputs
.
output_strs
[
i
].
strip
(
" "
)
==
hf_outputs
.
output_strs
[
i
],
(
assert
srt_outputs
.
output_strs
[
i
].
strip
(
" "
)
==
hf_outputs
.
output_strs
[
i
].
strip
(
" "
),
(
srt_outputs
.
output_strs
[
i
].
strip
(
" "
),
srt_outputs
.
output_strs
[
i
].
strip
(
" "
),
hf_outputs
.
output_strs
[
i
],
hf_outputs
.
output_strs
[
i
]
.
strip
(
" "
)
,
)
)
assert
(
assert
(
srt_no_lora_outputs
.
output_strs
[
i
].
strip
(
" "
)
srt_no_lora_outputs
.
output_strs
[
i
].
strip
(
" "
)
...
@@ -187,7 +189,7 @@ class TestLoRA(CustomTestCase):
...
@@ -187,7 +189,7 @@ class TestLoRA(CustomTestCase):
srt_no_lora_outputs
.
output_strs
[
i
].
strip
(
" "
),
srt_no_lora_outputs
.
output_strs
[
i
].
strip
(
" "
),
hf_no_lora_outputs
.
output_strs
[
i
],
hf_no_lora_outputs
.
output_strs
[
i
],
)
)
assert
srt_outputs_lora_path_none
==
srt_no_lora_outputs
#
assert srt_outputs_lora_path_none == srt_no_lora_outputs
def
serving
(
self
,
prompts
,
lora_set
,
tp_size
,
torch_dtype
,
max_new_tokens
):
def
serving
(
self
,
prompts
,
lora_set
,
tp_size
,
torch_dtype
,
max_new_tokens
):
print
(
"=================== testing serving ======================="
)
print
(
"=================== testing serving ======================="
)
...
@@ -287,7 +289,7 @@ class TestLoRA(CustomTestCase):
...
@@ -287,7 +289,7 @@ class TestLoRA(CustomTestCase):
tp_size
=
1
tp_size
=
1
max_new_tokens
=
32
max_new_tokens
=
32
self
.
inference
(
PROMPTS
,
lora_set
,
tp_size
,
torch_dtype
,
max_new_tokens
)
self
.
inference
(
PROMPTS
,
lora_set
,
tp_size
,
torch_dtype
,
max_new_tokens
)
self
.
serving
(
PROMPTS
,
lora_set
,
tp_size
,
torch_dtype
,
max_new_tokens
)
#
self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
# self.base_inference(
# self.base_inference(
# PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
# PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
# )
# )
...
...
test/srt/models/lora/test_multi_lora_backend.py
View file @
ef9a378a
...
@@ -19,17 +19,35 @@ from typing import List
...
@@ -19,17 +19,35 @@ from typing import List
import
torch
import
torch
from
utils
import
BACKENDS
,
TORCH_DTYPES
,
LoRAAdaptor
,
LoRAModelCase
from
utils
import
BACKENDS
,
TORCH_DTYPES
,
LoRAAdaptor
,
LoRAModelCase
from
sglang.test.test_utils
import
CustomTestCase
,
is_in_ci
from
sglang.test.runners
import
HFRunner
,
SRTRunner
from
sglang.test.test_utils
import
CustomTestCase
,
calculate_rouge_l
,
is_in_ci
MULTI_LORA_MODELS
=
[
MULTI_LORA_MODELS
=
[
# multi-rank case
LoRAModelCase
(
base
=
"meta-llama/Llama-2-7b-hf"
,
adaptors
=
[
LoRAAdaptor
(
name
=
"winddude/wizardLM-LlaMA-LoRA-7B"
,
prefill_tolerance
=
1e-1
,
),
LoRAAdaptor
(
name
=
"RuterNorway/Llama-2-7b-chat-norwegian-LoRa"
,
prefill_tolerance
=
3e-1
,
),
],
max_loras_per_batch
=
2
,
),
LoRAModelCase
(
LoRAModelCase
(
base
=
"meta-llama/Llama-3.1-8B-Instruct"
,
base
=
"meta-llama/Llama-3.1-8B-Instruct"
,
adaptors
=
[
adaptors
=
[
LoRAAdaptor
(
LoRAAdaptor
(
name
=
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
,
name
=
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
,
prefill_tolerance
=
1e-1
,
),
),
LoRAAdaptor
(
LoRAAdaptor
(
name
=
"some-org/another-lora-adaptor"
,
name
=
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
,
prefill_tolerance
=
1e-1
,
),
),
],
],
max_loras_per_batch
=
2
,
max_loras_per_batch
=
2
,
...
@@ -64,6 +82,7 @@ class TestMultiLoRABackend(CustomTestCase):
...
@@ -64,6 +82,7 @@ class TestMultiLoRABackend(CustomTestCase):
The multi-LoRA backend test functionality is not supported yet.
The multi-LoRA backend test functionality is not supported yet.
This function uses all prompts at once and prints a message indicating that support is pending.
This function uses all prompts at once and prints a message indicating that support is pending.
"""
"""
base_path
=
model_case
.
base
adaptor_names
=
[
adaptor
.
name
for
adaptor
in
model_case
.
adaptors
]
adaptor_names
=
[
adaptor
.
name
for
adaptor
in
model_case
.
adaptors
]
print
(
print
(
f
"
\n
========== Testing multi-LoRA backend '
{
backend
}
' for base '
{
model_case
.
base
}
' --- "
f
"
\n
========== Testing multi-LoRA backend '
{
backend
}
' for base '
{
model_case
.
base
}
' --- "
...
@@ -72,6 +91,118 @@ class TestMultiLoRABackend(CustomTestCase):
...
@@ -72,6 +91,118 @@ class TestMultiLoRABackend(CustomTestCase):
print
(
print
(
"run_backend_batch: Multi-LoRA backend test functionality is pending support."
"run_backend_batch: Multi-LoRA backend test functionality is pending support."
)
)
with
SRTRunner
(
base_path
,
torch_dtype
=
torch_dtype
,
model_type
=
"generation"
,
tp_size
=
model_case
.
tp_size
,
lora_paths
=
[
adaptor
.
name
for
adaptor
in
model_case
.
adaptors
],
max_loras_per_batch
=
model_case
.
max_loras_per_batch
,
lora_backend
=
backend
,
disable_cuda_graph
=
True
,
disable_radix_cache
=
True
,
mem_fraction_static
=
0.88
,
)
as
srt_runner
:
srt_outputs
=
srt_runner
.
forward
(
prompts
,
max_new_tokens
=
max_new_tokens
,
lora_paths
=
adaptor_names
)
with
HFRunner
(
base_path
,
torch_dtype
=
torch_dtype
,
model_type
=
"generation"
)
as
hf_runner
:
hf_outputs
=
hf_runner
.
forward
(
prompts
,
max_new_tokens
=
max_new_tokens
,
lora_paths
=
adaptor_names
)
with
SRTRunner
(
base_path
,
torch_dtype
=
torch_dtype
,
model_type
=
"generation"
,
tp_size
=
model_case
.
tp_size
,
mem_fraction_static
=
0.88
,
)
as
srt_runner
:
srt_no_lora_outputs
=
srt_runner
.
forward
(
prompts
,
max_new_tokens
=
max_new_tokens
)
with
HFRunner
(
base_path
,
torch_dtype
=
torch_dtype
,
model_type
=
"generation"
,
)
as
hf_runner
:
hf_no_lora_outputs
=
hf_runner
.
forward
(
prompts
,
max_new_tokens
=
max_new_tokens
)
# Compare prefill stage logprobs (HF vs SRTRunner with LoRA)
for
i
in
range
(
len
(
prompts
)):
adaptor
=
model_case
.
adaptors
[
i
]
# Use individual adapter tolerances if set, otherwise use model defaults
prefill_tol
=
(
adaptor
.
prefill_tolerance
if
adaptor
.
prefill_tolerance
is
not
None
else
model_case
.
prefill_tolerance
)
decode_tol
=
(
adaptor
.
decode_tolerance
if
adaptor
.
decode_tolerance
is
not
None
else
model_case
.
decode_tolerance
)
rouge_tol
=
(
adaptor
.
rouge_l_tolerance
if
adaptor
.
rouge_l_tolerance
is
not
None
else
model_case
.
rouge_l_tolerance
)
# Compare prefill stage logprobs (HF vs SRTRunner with LoRA)
hf_prefill
=
torch
.
tensor
(
hf_outputs
.
top_input_logprobs
[
i
])
srt_prefill
=
torch
.
tensor
(
srt_outputs
.
top_input_logprobs
[
i
])
max_prefill_diff
=
torch
.
max
(
torch
.
abs
(
hf_prefill
-
srt_prefill
))
print
(
"Max prefill diff (HF vs SRT):"
,
max_prefill_diff
)
# Compare decode stage logprobs
hf_decode
=
torch
.
tensor
(
hf_outputs
.
top_output_logprobs
[
i
])
srt_decode
=
torch
.
tensor
(
srt_outputs
.
top_output_logprobs
[
i
])
max_decode_diff
=
torch
.
max
(
torch
.
abs
(
hf_decode
-
srt_decode
))
print
(
"Max decode diff (HF vs SRT):"
,
max_decode_diff
)
srt_output_str
=
srt_outputs
.
output_strs
[
i
].
strip
()
hf_output_str
=
hf_outputs
.
output_strs
[
i
].
strip
()
rouge_score
=
calculate_rouge_l
([
srt_output_str
],
[
hf_output_str
])[
0
]
print
(
"ROUGE-L score:"
,
rouge_score
)
print
(
"SRT output:"
,
srt_output_str
)
print
(
"HF output:"
,
hf_output_str
)
# Additional: compare prefill outputs between base model (no LoRA) and LoRA model for reference
hf_no_lora_prefill
=
torch
.
tensor
(
hf_no_lora_outputs
.
top_input_logprobs
[
i
])
srt_no_lora_prefill
=
torch
.
tensor
(
srt_no_lora_outputs
.
top_input_logprobs
[
i
]
)
print
(
"Max diff (SRT base vs SRT LoRA prefill):"
,
torch
.
max
(
torch
.
abs
(
srt_no_lora_prefill
-
srt_prefill
)),
)
print
(
"Max diff (HF base vs HF LoRA prefill):"
,
torch
.
max
(
torch
.
abs
(
hf_no_lora_prefill
-
hf_prefill
)),
)
if
hf_prefill
.
shape
[
0
]
<=
100
:
assert
torch
.
all
(
torch
.
abs
(
hf_prefill
-
srt_prefill
)
<
prefill_tol
),
(
f
"Prefill logprobs mismatch for base '
{
base_path
}
', adaptor '
{
adaptor_names
}
', "
f
"backend '
{
backend
}
', prompt: '
{
prompts
[
0
][:
50
]
}
...'"
)
if
hf_decode
.
shape
[
0
]
<=
100
:
assert
torch
.
all
(
torch
.
abs
(
hf_decode
-
srt_decode
)
<
decode_tol
),
(
f
"Decode logprobs mismatch for base '
{
base_path
}
', adaptor '
{
adaptor_names
}
', "
f
"backend '
{
backend
}
', prompt: '
{
prompts
[
0
][:
50
]
}
...'"
)
if
rouge_score
<
rouge_tol
:
raise
AssertionError
(
f
"ROUGE-L score
{
rouge_score
}
below tolerance
{
rouge_tol
}
"
f
"for base '
{
base_path
}
', adaptor '
{
adaptor_names
}
', backend '
{
backend
}
', prompt: '
{
prompts
[
0
][:
50
]
}
...'"
)
def
_run_backend_on_model_cases
(
self
,
model_cases
:
List
[
LoRAModelCase
]):
def
_run_backend_on_model_cases
(
self
,
model_cases
:
List
[
LoRAModelCase
]):
for
model_case
in
model_cases
:
for
model_case
in
model_cases
:
...
...
test/srt/models/lora/utils.py
View file @
ef9a378a
...
@@ -31,8 +31,8 @@ class LoRAModelCase:
...
@@ -31,8 +31,8 @@ class LoRAModelCase:
base
:
str
base
:
str
adaptors
:
List
[
LoRAAdaptor
]
adaptors
:
List
[
LoRAAdaptor
]
tp_size
:
int
=
1
tp_size
:
int
=
1
prefill_tolerance
:
float
=
5
e-
2
prefill_tolerance
:
float
=
1
e-
1
decode_tolerance
:
float
=
5
e-
2
decode_tolerance
:
float
=
1
e-
1
rouge_l_tolerance
:
float
=
1.0
rouge_l_tolerance
:
float
=
1.0
max_loras_per_batch
:
int
=
1
max_loras_per_batch
:
int
=
1
skip_long_prompt
:
bool
=
False
skip_long_prompt
:
bool
=
False
...
...
test/srt/run_suite.py
View file @
ef9a378a
...
@@ -15,7 +15,7 @@ suites = {
...
@@ -15,7 +15,7 @@ suites = {
"per-commit"
:
[
"per-commit"
:
[
TestFile
(
"models/lora/test_lora.py"
,
76
),
TestFile
(
"models/lora/test_lora.py"
,
76
),
TestFile
(
"models/lora/test_lora_backend.py"
,
420
),
TestFile
(
"models/lora/test_lora_backend.py"
,
420
),
TestFile
(
"models/lora/test_multi_lora_backend.py"
,
1
),
TestFile
(
"models/lora/test_multi_lora_backend.py"
,
1
44
),
TestFile
(
"models/test_embedding_models.py"
,
119
),
TestFile
(
"models/test_embedding_models.py"
,
119
),
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_generation_models.py"
,
103
),
TestFile
(
"models/test_grok_models.py"
,
60
),
TestFile
(
"models/test_grok_models.py"
,
60
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment