Commit 9c064bf7 (unverified)
Authored Oct 06, 2024 by Ying Sheng; committed by GitHub on Oct 06, 2024
Parent: 58d1082e

[LoRA, Performance] Speedup multi-LoRA serving - Step 1 (#1587)
Showing 3 changed files, with 34 additions and 32 deletions:

    benchmark/lora/launch_server.py          +3   -9
    python/sglang/srt/lora/lora.py           +13  -14
    python/sglang/srt/lora/lora_manager.py   +18  -9
benchmark/lora/launch_server.py

 import argparse
 import os

-NUM_LORAS = 128
+NUM_LORAS = 8
 LORA_PATH = {
     "base": "mistralai/Mistral-7B-Instruct-v0.3",
     "lora": "/home/ying/test_lora",
...
@@ -11,12 +11,11 @@ LORA_PATH = {
 def launch_server(args):
     base_path = LORA_PATH["base"]
     lora_path = LORA_PATH["lora"]
-    max_loras_per_batch = 4

     if args.base_only:
-        cmd = f"python -m sglang.launch_server --model {base_path}"
+        cmd = f"python3 -m sglang.launch_server --model {base_path}"
     else:
-        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
+        cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
         for i in range(NUM_LORAS):
             lora_name = f"lora{i}"
             cmd += f"{lora_name}={lora_path} "
...
@@ -29,11 +28,6 @@ def launch_server(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--num-loras",
-        type=int,
-        default=128,
-    )
     parser.add_argument(
         "--base-only",
         action="store_true",
...
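Note: the benchmark script shells out to sglang.launch_server, registering each adapter under the name lora{i}. A minimal sketch of the command the updated script assembles for NUM_LORAS = 8 (only the parts visible in this diff are reproduced; any flags added in the elided portions of the file are not shown):

    # Sketch: the server-launch command built by the benchmark (visible parts only).
    NUM_LORAS = 8
    base_path = "mistralai/Mistral-7B-Instruct-v0.3"
    lora_path = "/home/ying/test_lora"

    cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
    cmd += " ".join(f"lora{i}={lora_path}" for i in range(NUM_LORAS))
    print(cmd)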
python/sglang/srt/lora/lora.py

...
@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
...
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
-        # FIXME
         assert lora_a_output.shape[-1] == self.lora_rank * 2
         lora_output = torch.empty_like(base_output)
         output_dim = lora_output.shape[-1] // 2
         for i in range(2):
...
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.B_buffer[:, left:right, :].contiguous(),
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         return base_output + lora_output * self.scaling
...
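Note: the recurring edit in this file swaps the seg_lens argument (per-request segment lengths) of segment_gemm.run for seg_indptr, the zero-prefixed cumulative sum of those lengths, so row offsets are computed once per batch in LoRAManager instead of being rederived for every layer's GEMM call. A minimal sketch of the relationship, assuming flashinfer-style segment indexing where segment i spans rows seg_indptr[i]:seg_indptr[i + 1]:

    import torch

    seg_lens = torch.tensor([3, 1, 4], dtype=torch.int32)  # tokens per request
    seg_indptr = torch.zeros((seg_lens.numel() + 1,), dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
    # seg_indptr is now [0, 3, 4, 8]; segment i covers rows
    # seg_indptr[i]:seg_indptr[i + 1] of the stacked token matrix.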
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

     def set_lora_info(
-        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
         self.B_buffer_q = B_buffer_q
         self.B_buffer_kv = B_buffer_kv
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
...
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.A_buffer_qkv,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # FIXME parallelize qkv
...
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             weights=self.B_buffer_q,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         # kv
...
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                 batch_size=self.bs,
                 weight_column_major=True,
-                seg_lens=self.seq_lens,
+                seg_indptr=self.seg_indptr,
                 weight_indices=self.weight_indices,
             )
         )
...
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
-        self.seq_lens = seq_lens
+        self.seg_indptr = seg_indptr
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
...
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.A_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         lora_output = self.segment_gemm.run(
...
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             weights=self.B_buffer,
             batch_size=self.bs,
             weight_column_major=True,
-            seg_lens=self.seq_lens,
+            seg_indptr=self.seg_indptr,
             weight_indices=self.weight_indices,
         )
         return base_output + lora_output * self.scaling
...
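Note: across all three layer classes, apply_lora computes the standard LoRA update, y = base(x) + scaling * ((x @ A^T) @ B^T), with the two matmuls batched per segment and per adapter by segment_gemm. A dense, single-adapter sketch of the same arithmetic, with plain matmuls standing in for the segment GEMMs (the shapes assume weight_column_major=True stores weights as (out_features, in_features), which is an assumption of this sketch):

    import torch

    def apply_lora_dense(base_output, x, A, B, scaling):
        # x: (tokens, hidden); A: (rank, hidden); B: (out_dim, rank).
        lora_a = x @ A.t()          # (tokens, rank)
        lora = lora_a @ B.t()       # (tokens, out_dim)
        return base_output + lora * scaling

    # Example: 10 tokens, hidden size 16, rank 4, output size 16.
    x = torch.randn(10, 16)
    out = apply_lora_dense(
        torch.zeros(10, 16), x, torch.randn(4, 16), torch.randn(16, 4), 0.5
    )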
python/sglang/srt/lora/lora_manager.py

...
@@ -274,18 +274,24 @@ class LoRAManager:
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
+        j = len(self.active_uids)
         evictable_uids = list(self.active_uids)
         for uid in cur_uids:
             if uid not in self.active_uids:
-                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                    i += 1
-                if i < len(evictable_uids):
+                if j < self.max_loras_per_batch:
+                    index = j
+                    j += 1
+                else:
+                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
+                        i += 1
+                    assert i < len(evictable_uids)
                     self.active_uids.remove(evictable_uids[i])
                     self.buffer_id.pop(evictable_uids[i])
-                self.load_lora(uid, i)
+                    index = i
+                    i += 1
+                self.load_lora(uid, index)
                 self.active_uids.add(uid)
-                self.buffer_id[uid] = i
-                i += 1
+                self.buffer_id[uid] = index

         if cur_uids == set([None]):
             return
...
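Note: the reworked loop splits adapter placement into two cases. While fewer than max_loras_per_batch adapters are resident, a new adapter takes the next free buffer slot (tracked by j); only once the pool is full does it evict a resident adapter that the current batch does not need (scanning evictable_uids with i). A standalone sketch of the policy, with a hypothetical function name, load_lora standing in for the weight copy into slot `index`, and slot order assumed to follow active_uids insertion order:

    def assign_lora_slots(cur_uids, active_uids, buffer_id, max_loras_per_batch, load_lora):
        # Give every adapter in cur_uids a buffer slot, evicting only when full.
        assert len(cur_uids) <= max_loras_per_batch
        i = 0
        j = len(active_uids)              # next free slot, if any remain
        evictable_uids = list(active_uids)
        for uid in cur_uids:
            if uid in active_uids:
                continue                  # already resident; slot unchanged
            if j < max_loras_per_batch:
                index = j                 # fill an empty slot first
                j += 1
            else:
                # Pool is full: evict a resident adapter this batch does not use.
                while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
                    i += 1
                assert i < len(evictable_uids)
                active_uids.remove(evictable_uids[i])
                buffer_id.pop(evictable_uids[i])
                index = i
                i += 1
            load_lora(uid, index)         # copy adapter weights into the slot
            active_uids.add(uid)
            buffer_id[uid] = index

Unlike the old loop, which reused the scan cursor i as the slot index and evicted on every miss, resident adapters now keep their slots and empty slots are consumed before anything is evicted.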
@@ -295,8 +301,11 @@ class LoRAManager:
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs)
+            else torch.ones(bs, device="cuda")
         )
+        # FIXME: reuse the data rather than recompute
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]
...
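Note: this hunk centralizes the per-batch bookkeeping that the layers now consume: seg_lens comes from extend_seq_lens during prefill and is all ones during decode, seg_indptr is its zero-prefixed cumulative sum, and weight_indices maps each request to the buffer slot holding its adapter. A CPU-only sketch with made-up batch values (the real code allocates these tensors on "cuda"):

    import torch

    # Hypothetical batch: three requests with prefill lengths 5, 2, 7, whose
    # adapters are resident in buffer slots 1, 0, 1 respectively.
    seg_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
    buffer_id = {"loraA": 1, "loraB": 0}
    lora_paths = ["loraA", "loraB", "loraA"]

    bs = seg_lens.numel()
    seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)   # tensor([0, 5, 7, 14])

    weight_indices = torch.empty((bs,), dtype=torch.int64)
    for i, lora_path in enumerate(lora_paths):
        weight_indices[i] = buffer_id[lora_path]     # tensor([1, 0, 1])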
@@ -310,7 +319,7 @@ class LoRAManager:
                     self.A_buffer[weight_name][layer_id],
                     self.B_buffer[weight_name][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )
             else:
...
@@ -319,6 +328,6 @@ class LoRAManager:
                     self.B_buffer["q_proj"][layer_id],
                     self.B_buffer["kv_proj"][layer_id],
                     bs,
-                    seg_lens,
+                    seg_indptr,
                     weight_indices,
                 )