change / sglang · Commits · c45cab1c
"torchvision/transforms/_functional_tensor.py" did not exist on "00c119c853a74848655799c9b185cedf7a01f891"
Commit c45cab1c (Unverified) · authored Feb 09, 2025 by Baizhou Zhang, committed by GitHub on Feb 10, 2025

[Fix] Fix accuracy bug and refactor codes for lora (#3413)

Parent: 27c4c9cf
Changes: 15
Showing 15 changed files with 1137 additions and 631 deletions (+1137, -631):

python/sglang/srt/lora/backend/__init__.py            +25   -5
python/sglang/srt/lora/backend/base_backend.py        +31   -9
python/sglang/srt/lora/backend/flashinfer_backend.py  +41   -4
python/sglang/srt/lora/backend/triton_backend.py      +34   -4
python/sglang/srt/lora/layers.py                      +293  -0
python/sglang/srt/lora/lora.py                        +101  -326
python/sglang/srt/lora/lora_manager.py                +101  -269
python/sglang/srt/lora/mem_pool.py                    +174  -0
python/sglang/srt/lora/triton_ops/__init__.py         +7    -1
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py   +170  -0
python/sglang/srt/lora/triton_ops/qkv_lora_b.py       +5    -5
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py     +2    -2
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py     +2    -2
python/sglang/srt/lora/utils.py                       +141  -0
test/srt/models/test_lora_backend.py                  +10   -4
python/sglang/srt/lora/backend/__init__.py — view file @ c45cab1c

-from .base_backend import BaseLoraBackend
-from .flashinfer_backend import FlashInferLoraBackend
-from .triton_backend import TritonLoraBackend
+from .base_backend import BaseLoRABackend
+from .flashinfer_backend import FlashInferLoRABackend
+from .triton_backend import TritonLoRABackend
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    backend_mapping = {
+        "triton": TritonLoRABackend,
+        "flashinfer": FlashInferLoRABackend,
+    }
+    if name in backend_mapping:
+        return backend_mapping[name]
+
+    raise Exception(
+        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+    )
 
 __all__ = [
-    "FlashInferLoraBackend",
-    "TritonLoraBackend",
+    "BaseLoRABackend",
+    "FlashInferLoRABackend",
+    "TritonLoRABackend",
+    "get_backend_from_name",
 ]
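For orientation (not part of the commit): a minimal sketch of how the factory above is used by the manager; the only assumption is that the backend name string is "triton".

import torch  # noqa: F401  (backends operate on torch tensors)
from sglang.srt.lora.backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")   # -> TritonLoRABackend
backend = backend_cls("triton")                 # the LoRABatchInfo is attached later via set_batch_info()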
python/sglang/srt/lora/backend/base_backend.py — view file @ c45cab1c

@@ -2,7 +2,7 @@ from typing import Tuple, Union
 
 import torch
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 def get_fuse_output_scaling_add_from_name(name: str) -> bool:
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
     return mapping.get(name, False)
 
 
-def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
     return mapping.get(name, False)
 
 
-class BaseLoraBackend:
+class BaseLoRABackend:
     """Base class for different Lora backends.
        Each backend has its own implementation of Lora kernels.
@@ -32,11 +32,11 @@ class BaseLoraBackend:
        and the operation of scaling and adding will be fused into kernel
     """
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
-        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -46,10 +46,11 @@ class BaseLoraBackend:
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+                     here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
                usually input_dim is much larger than r
         Returns:
-            result with shape (s, r)
+            result with shape (s, c * r)
         """
         pass
@@ -83,7 +84,7 @@ class BaseLoraBackend:
             qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
             qkv_lora_b: lora_b module for qkv.
                 If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
-                If passed in as a tuple of two tensors containing:
+                If passed in as a tuple of two tensors, it should contain:
                     a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
                     and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
         Returns:
@@ -91,5 +92,26 @@ class BaseLoraBackend:
         """
         pass
 
-    def set_batch_info(self, batch_info: LoraBatchInfo):
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
+            gate_up_lora_b: lora_b module for gate_up_proj.
+                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
+                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+        Returns:
+            result with shape (s, 2 * output_dim)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
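A rough reference (not from the commit) for the contract the hooks above describe, written in plain PyTorch. Shapes follow the docstrings; weight_indices is assumed to be already expanded to one adapter index per token row, and none of this is the actual kernel implementation.

import torch

def run_lora_a_sgemm_ref(x, weights, weight_indices):
    # x: (s, input_dim), weights: (num_lora, c * r, input_dim) -> result: (s, c * r)
    return torch.einsum("si,sri->sr", x, weights[weight_indices])

def run_lora_b_sgemm_ref(x, weights, weight_indices, base_output=None, scaling=1.0):
    # x: (s, r), weights: (num_lora, output_dim, r) -> result: (s, output_dim)
    out = torch.einsum("sr,sor->so", x, weights[weight_indices]) * scaling
    # backends with fuse_output_scaling_add do the add/scale inside the kernel
    return out if base_output is None else base_output + out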
python/sglang/srt/lora/backend/flashinfer_backend.py — view file @ c45cab1c

@@ -2,17 +2,17 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
 
 
-class FlashInferLoraBackend(BaseLoraBackend):
+class FlashInferLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
         # Set up SGemm Wrapper from flashinfer
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
         **kwargs,
     ) -> torch.Tensor:
 
+        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
+
         # Shape of lora_a_output: (s, 3 * r)
         lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
         )
 
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
+        lora_rank = gate_up_lora_b[0].shape[-1]
+        output_dim = gate_up_lora_b[0].shape[-2]
+
+        # Shape of lora_a_output: (s, 2 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
+
+        lora_output = torch.empty(
+            (x.shape[0], 2 * output_dim),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # Compute lora for gate and up proj respectively
+        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(),
+            weights=gate_up_lora_b[0],
+        )
+
+        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, lora_rank:].contiguous(),
+            weights=gate_up_lora_b[1],
+        )
+
+        return lora_output
python/sglang/srt/lora/backend/triton_backend.py — view file @ c45cab1c

 import torch
 
-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
+    gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
     sgemm_lora_a_fwd,
     sgemm_lora_b_fwd,
 )
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
-class TritonLoraBackend(BaseLoraBackend):
+class TritonLoRABackend(BaseLoRABackend):
 
-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)
 
     def run_lora_a_sgemm(
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
             scaling,
         )
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: torch.Tensor,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        # x: (s, input_dim)
+        # gate_up_lora_a: (num_lora, 2 * r, input_dim)
+        # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+        assert isinstance(gate_up_lora_b, torch.Tensor)
+        output_dim = gate_up_lora_b.shape[-2] // 2
+
+        # lora_a_output: (s, 2 * r)
+        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+        lora_output = gate_up_lora_b_fwd(
+            lora_a_output,
+            gate_up_lora_b,
+            self.batch_info,
+            output_dim,
+            base_output,
+            scaling,
+        )
+        return lora_output
python/sglang/srt/lora/layers.py — new file (0 → 100644), view file @ c45cab1c

import torch
from torch import nn

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend


class BaseLayerWithLoRA(nn.Module):
    def __init__(
        self,
        base_layer: nn.Module,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ):
        super().__init__()
        self.base_layer: nn.Module = base_layer
        self.lora_rank: int = lora_rank
        self.scaling: float = scaling
        self.set_lora: bool = False
        self.lora_backend: BaseLoRABackend = lora_backend

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

    def set_lora_info(self, *args):
        pass


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self,
        base_layer: VocabParallelEmbedding,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self,
        base_layer: ColumnParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ColumnParallelLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)

        if self.base_layer.gather_output:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self,
        base_layer: MergedColumnParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer_gate_up = A_buffer
        if self.lora_backend.fuse_stacked_lora_b:
            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
            self.B_buffer_gate_up = torch.cat(
                (B_buffer[0], B_buffer[1]), dim=-2
            ).contiguous()
        else:
            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}

        lora_output = self.lora_backend.run_gate_up_lora(
            x,
            self.A_buffer_gate_up,
            self.B_buffer_gate_up,
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self,
        base_layer: QKVParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self,
        A_buffer_qkv: torch.Tensor,
        B_buffer_q: torch.Tensor,
        B_buffer_kv: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv

        if self.lora_backend.fuse_stacked_lora_b:
            assert (
                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
            self.B_buffer_qkv = torch.cat(
                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
            ).contiguous()

            # Offsets of q/k/v in output dimension
            self.output_offset = torch.tensor(
                [
                    0,
                    output_dim_q,
                    output_dim_q + output_dim_kv,
                    output_dim_q + 2 * output_dim_kv,
                ],
                dtype=torch.int32,
                device=B_buffer_q.device,
            )
            # For computing number of launched blocks
            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
        else:
            self.B_buffer_qkv = (
                B_buffer_q,
                B_buffer_kv,
            )

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        if self.lora_backend.fuse_stacked_lora_b:
            backend_kwargs["output_offset"] = self.output_offset
            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim

        lora_output = self.lora_backend.run_qkv_lora(
            x,
            self.A_buffer_qkv,
            self.B_buffer_qkv,
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self,
        base_layer: RowParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in RowParallelLinear
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()

        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_parallel
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (
                output_ + self.base_layer.bias
                if self.base_layer.bias is not None
                else output_
            )
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias


def get_lora_layer(
    layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLoRA,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
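The apply_lora methods above all reduce to the same LoRA update; a tiny self-contained sketch of that math in plain PyTorch (illustrative shapes only, not from the commit):

import torch

s, input_dim, output_dim, r, scaling = 4, 8, 6, 2, 0.5
x = torch.randn(s, input_dim)
W = torch.randn(output_dim, input_dim)      # frozen base weight
A = torch.randn(r, input_dim)               # lora_A
B = torch.randn(output_dim, r)              # lora_B

base_output = x @ W.T
lora_output = (x @ A.T) @ B.T               # what run_lora_a_sgemm + run_lora_b_sgemm compute
y = base_output + lora_output * scaling     # unfused path; fused backends add/scale inside the kernel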
python/sglang/srt/lora/lora.py — view file @ c45cab1c

@@ -19,282 +19,25 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 import re
-from dataclasses import dataclass
+from typing import Dict, List
 
 import torch
 from torch import nn
 
-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
-@dataclass
-class LoraBatchInfo:
-    # Batch size
-    bs: int
-
-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
-
-    # Indice pointers of each sequence in shape (bs + 1, )
-    seg_indptr: torch.Tensor
-
-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
-    weight_indices: torch.Tensor
-
-
-class BaseLayerWithLoRA(nn.Module):
-    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
-        super().__init__()
-        self.base_layer = base_layer
-        self.lora_rank = lora_rank
-        self.scaling = scaling
-        self.set_lora = False
-        self.lora_backend = lora_backend
-
-    def forward(self, x: torch.Tensor):
-        return self.base_layer.forward(x)
-
-    def set_lora_info(self, *args):
-        pass
-
-
-class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-        self.weight = base_layer.weight
-
-
-class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        # TODO
-        return output
-
-    def forward(self, input_: torch.Tensor):
-        # duplicate the logic in ColumnParallelLinear
-        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_, bias
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_)
-
-        if self.base_layer.gather_output:
-            output = tensor_model_parallel_all_gather(output_parallel)
-        else:
-            output = output_parallel
-
-        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
-        return output, output_bias
-
-
-class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def __init__(
-        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer,
-        B_buffer,
-    ):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
-
-        output_dim = base_output.shape[-1]
-        lora_output = torch.empty_like(base_output)
-        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
-            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
-            weights=self.B_buffer[0],
-        )
-
-        lora_output[:, output_dim : 2 * output_dim] = (
-            self.lora_backend.run_lora_b_sgemm(
-                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
-                weights=self.B_buffer[1],
-            )
-        )
-
-        return base_output + lora_output * self.scaling
-
-
-class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
-        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer_qkv,
-        B_buffer_q,
-        B_buffer_kv,
-    ):
-        self.set_lora = True
-        self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_qkv_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
-
-            # Offsets of q/k/v in output dimension
-            self.output_offset = torch.tensor(
-                [
-                    0,
-                    output_dim_q,
-                    output_dim_q + output_dim_kv,
-                    output_dim_q + 2 * output_dim_kv,
-                ],
-                dtype=torch.int32,
-                device=B_buffer_q.device,
-            )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
-            self.output_offset = None
-            self.max_qkv_out_dim = None
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            output_offset=self.output_offset,
-            max_qkv_out_dim=self.max_qkv_out_dim,
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-
-class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(self, A_buffer, B_buffer):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
-        lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-    def forward(self, input_):
-        # duplicate the logic in RowParallelLinear
-        if self.base_layer.input_is_parallel:
-            input_parallel = input_
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size
-            )
-            input_parallel = splitted_input[tp_rank].contiguous()
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_parallel
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_parallel)
-
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output_ = output_parallel
-
-        if not self.base_layer.skip_bias_add:
-            output = (
-                output_ + self.base_layer.bias
-                if self.base_layer.bias is not None
-                else output_
-            )
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.base_layer.bias
-        return output, output_bias
-
-
-def get_lora_layer(
-    layer: nn.Module, lora_rank, scaling, lora_backend
-) -> BaseLayerWithLoRA:
-    supported_layer_types = {
-        # the order matters
-        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
-        QKVParallelLinear: QKVParallelLinearWithLoRA,
-        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
-        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
-        RowParallelLinear: RowParallelLinearWithLoRA,
-    }
-    for src_layer_type, lora_layer_type in supported_layer_types.items():
-        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
-            return ret
-    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
-
-
-def get_mapped_params(module_names):
-    ret = set()
-    for module_name in module_names:
-        ret.add(params_mapping(module_name))
-    return list(ret)
-
-
 class LoRALayer(nn.Module):
-    def __init__(self, config, base_hf_config):
+    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
         super().__init__()
-        self.config = config
-        self.base_hf_config = base_hf_config
-        self.weights = {}
-        self.weight_gpu = {}
+        self.config: LoRAConfig = config
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weight_gpu: Dict[str, torch.Tensor] = {}
 
     def load_to_gpu(self):
         for name, weight in self.weights.items():
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
 
 class LoRAAdapter(nn.Module):
-    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
+    def __init__(
+        self,
+        uid: str,
+        config: LoRAConfig,
+        base_hf_config: AutoConfig,
+        load_config: LoadConfig,
+        lora_backend: BaseLoRABackend,
+    ):
         super().__init__()
-        self.uid = uid
-        self.config = config
+        self.uid: str = uid
+        self.config: LoRAConfig = config
         assert self.config.hf_config["peft_type"].lower() == "lora"
-        self.base_hf_config = base_hf_config
-        self.load_config = load_config
-        self.lora_backend = lora_backend
-        self.scaling = self.config.lora_alpha / self.config.r
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.load_config: LoadConfig = load_config
+        self.lora_backend: BaseLoRABackend = lora_backend
+        self.scaling: float = self.config.lora_alpha / self.config.r
 
-        self.layers = nn.ModuleList(
+        self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
                 for i in range(base_hf_config.num_hidden_layers)
             ]
         )
 
-        self.weights = {}
-        self.weights_gpu = {}
-
-    def get_stacked_multiply(self, module_name):
-        stacked_rank = {
-            "qkv_proj": 3,
-            "kv_proj": 2,
-            "gate_up_proj": 2,
-        }
-        return stacked_rank[module_name] if module_name in stacked_rank else 1
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weights_gpu: Dict[str, torch.Tensor] = {}
 
     def load_to_gpu(self):
         for name, weight in self.weights.items():
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-            for weight_name in weight_names:
-                if "k_proj" in weight_name:
-                    q_name = weight_name.replace("k_proj", "q_proj")
-                    v_name = weight_name.replace("k_proj", "v_proj")
-                    kv_name = weight_name.replace("k_proj", "kv_proj")
-                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
-                    if "lora_A" in weight_name:
-                        layer.weights[qkv_name] = torch.cat(
-                            (
-                                layer.weights[q_name],
-                                layer.weights[weight_name],
-                                layer.weights[v_name],
-                            ),
-                            0,
-                        )
-                        layer.weights.pop(q_name)
-                        layer.weights.pop(weight_name)
-                        layer.weights.pop(v_name)
-                    else:
-                        layer.weights[kv_name] = torch.stack(
-                            [
-                                layer.weights[weight_name],
-                                layer.weights[v_name],
-                            ],
-                            dim=0,
-                        )
-                        layer.weights.pop(weight_name)
-                        layer.weights.pop(v_name)
-                elif "gate_proj" in weight_name:
-                    up_name = weight_name.replace("gate_proj", "up_proj")
-                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-                    if "lora_A" in weight_name:
-                        layer.weights[gate_up_name] = torch.cat(
-                            (layer.weights[weight_name], layer.weights[up_name]), 0
-                        )
-                    else:
-                        layer.weights[gate_up_name] = torch.stack(
-                            [layer.weights[weight_name], layer.weights[up_name]], dim=0
-                        )
-                    layer.weights.pop(weight_name)
-                    layer.weights.pop(up_name)
+            self.stack_qkv_proj(weight_names, layer.weights)
+            self.stack_gate_up_proj(weight_names, layer.weights)
+
+    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
+        target_module = set()
+        for weight_name in weight_names:
+            if "k_proj" in weight_name:
+                target_module.add("k_proj")
+            if "q_proj" in weight_name:
+                target_module.add("q_proj")
+            if "v_proj" in weight_name:
+                target_module.add("v_proj")
+        if len(target_module) == 0:
+            return
+
+        for weight_name in weight_names:
+            # We assume every lora adapter should contain lora modules for q_proj
+            if "q_proj" in weight_name:
+                q_name = weight_name
+                k_name = weight_name.replace("q_proj", "k_proj")
+                v_name = weight_name.replace("q_proj", "v_proj")
+                kv_name = weight_name.replace("q_proj", "kv_proj")
+                qkv_name = weight_name.replace("q_proj", "qkv_proj")
+
+                # If k_proj doesn't have lora, initialize it to zero
+                k_proj_weight = (
+                    weights[k_name]
+                    if "k_proj" in target_module
+                    else torch.zeros_like(weights[v_name])
+                )
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = torch.cat(
+                        (
+                            weights[q_name],
+                            k_proj_weight,
+                            weights[v_name],
+                        ),
+                        0,
+                    )
+                    weights.pop(q_name)
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+                else:
+                    weights[kv_name] = torch.stack(
+                        [
+                            k_proj_weight,
+                            weights[v_name],
+                        ],
+                        dim=0,
+                    )
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+
+    def stack_gate_up_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
+        for weight_name in weight_names:
+            if "gate_proj" in weight_name:
+                up_name = weight_name.replace("gate_proj", "up_proj")
+                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = torch.cat(
+                        (weights[weight_name], weights[up_name]), 0
+                    )
+                else:
+                    weights[gate_up_name] = torch.stack(
+                        [weights[weight_name], weights[up_name]], dim=0
+                    )
+                weights.pop(weight_name)
+                weights.pop(up_name)
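A quick illustration (not from the commit; weight names are made up) of what the new stack_qkv_proj does for an adapter that has q_proj/v_proj LoRA but no k_proj: the lora_A slices are concatenated along dim 0 and the missing k_proj slice is zero-filled, which is the accuracy fix this commit refers to.

import torch

r, hidden = 2, 8
weights = {
    "self_attn.q_proj.lora_A.weight": torch.randn(r, hidden),
    "self_attn.v_proj.lora_A.weight": torch.randn(r, hidden),
}
q = weights["self_attn.q_proj.lora_A.weight"]
v = weights["self_attn.v_proj.lora_A.weight"]
k = torch.zeros_like(v)                  # k_proj has no LoRA, so it contributes zeros
qkv_lora_a = torch.cat((q, k, v), 0)     # shape (3 * r, hidden), what gets copied into A_buffer["qkv_proj"]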
python/sglang/srt/lora/lora_manager.py — view file @ c45cab1c

@@ -16,307 +16,115 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
-import re
+from typing import Dict, List, Set, Tuple
 
 import torch
 
-from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
-from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.mem_pool import LoRAMemoryPool
+from sglang.srt.lora.utils import (
+    LoRABatchInfo,
+    LoRAType,
+    get_customized_names_from_hf_names,
+    get_layer_id,
+    get_stacked_name,
+    get_weight_name,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available, replace_submodule
+from sglang.srt.utils import replace_submodule
 
 logger = logging.getLogger(__name__)
 
 
-def get_module_name(name):
-    # Fallback solution of mapping from config module name to module name in model class.
-    # Please check if it aligns with your base model.
-    # Please implement the function in the model class if it is not.
-    # You can reference this function in llama.py.
-    params_mapping = {
-        "q_proj": "qkv_proj",
-        "k_proj": "qkv_proj",
-        "v_proj": "qkv_proj",
-        "gate_proj": "gate_up_proj",
-        "up_proj": "gate_up_proj",
-    }
-    return params_mapping.get(name, name)
-
-
-def get_hidden_dim(module_name, config):
-    # Fallback solution of get_hidden_dim for different modules
-    # Please check if it aligns with your base model.
-    # Please implement the function in the model class if it is not.
-    # You can reference this function in llama.py.
-    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-        return config.hidden_size, config.hidden_size
-    elif module_name in ["kv_proj"]:
-        return config.hidden_size, config.hidden_size // (
-            config.num_attention_heads // config.num_key_value_heads
-        )
-    elif module_name == "gate_up_proj":
-        return config.hidden_size, config.intermediate_size
-    elif module_name == "down_proj":
-        return config.intermediate_size, config.hidden_size
-    else:
-        raise NotImplementedError()
-
-
-def get_stacked_name(name):
-    # origin name -> (name for A, name for B)
-    params_mapping = {
-        "q_proj": ("qkv_proj", "q_proj"),
-        "k_proj": ("qkv_proj", "kv_proj"),
-        "v_proj": ("qkv_proj", "kv_proj"),
-        "gate_proj": ("gate_up_proj", "gate_up_proj"),
-        "up_proj": ("gate_up_proj", "gate_up_proj"),
-    }
-    return params_mapping.get(name, (name, name))
-
-
-def get_backend_from_name(name):
-    backend_mapping = {
-        "triton": TritonLoraBackend,
-        "flashinfer": FlashInferLoraBackend,
-    }
-    if name in backend_mapping:
-        return backend_mapping[name]
-
-    raise Exception(
-        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
-    )
-
-
-def get_layer_id(name):
-    match = re.search(r"layers\.(\d+)\.", name)
-    if match is None:
-        return None
-    return int(match.group(1))
-
-
 class LoRAManager:
     def __init__(
         self,
-        base_model,
-        lora_paths,
-        base_hf_config,
-        max_loras_per_batch,
-        load_config,
-        dtype,
-        lora_backend,
+        base_model: torch.nn.Module,
+        lora_paths: Dict[str, str],
+        base_hf_config: AutoConfig,
+        max_loras_per_batch: int,
+        load_config: LoadConfig,
+        dtype: torch.dtype,
+        lora_backend: str = "triton",
     ):
-        self.base_model = base_model
-        self.lora_paths = lora_paths
-        self.base_hf_config = base_hf_config
-        self.max_loras_per_batch = max_loras_per_batch
-        self.load_config = load_config
-        self.dtype = dtype
-
-        logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+        self.base_model: torch.nn.Module = base_model
+        self.lora_paths: Dict[str, str] = lora_paths
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.max_loras_per_batch: int = max_loras_per_batch
+        self.load_config: LoadConfig = load_config
+        self.dtype: torch.dtype = dtype
+
+        # LoRA backend for running sgemm kernels
+        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
 
         self.init_loras()
         self.init_lora_memory_pool()
-        self.init_lora_batch()
-
-    def match_target_modules(self, module_name):
-        for target_module in self.target_modules:
-            if module_name.split(".")[-1] == target_module:
-                return True
-        return False
-
-    def get_target_modules(self):
-        modules = []
-        for module_name, module in self.base_model.named_modules():
-            if self.match_target_modules(module_name):
-                modules.append((module_name, module))
-        return modules
-
-    def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
-        replace_submodule(self.base_model, module_name, lora_module)
-        return lora_module
 
     def init_loras(self):
-        # get configs and target modules
-        self.configs = {}
-        self.origin_target_modules = set()
+        # Config of each LoRA adapter
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
-            self.origin_target_modules = set(self.origin_target_modules) | set(
+            self.hf_target_names = set(self.hf_target_names) | set(
                 self.configs[name].target_modules
             )
 
-        if hasattr(self.base_model, "get_module_name"):
-            self.target_modules = {
-                self.base_model.get_module_name(module)
-                for module in self.origin_target_modules
-            }
-        else:
-            logger.warning(
-                "WARNING: get_module_name() is not defined, "
-                "which is used to map config module name to model implementation module name."
-                "Use the default one, but please check if it is correct for your model."
-            )
-            self.target_modules = {
-                get_module_name(module) for module in self.origin_target_modules
-            }
-
-        self.target_weights = set(
-            [get_stacked_name(module) for module in self.origin_target_modules]
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
+        self.lora_weight_names: Set[Tuple[str]] = set(
+            [get_stacked_name(module) for module in self.hf_target_names]
         )
 
         # load all weights to cpu
-        self.loras = []
-        self.lora_id = {}
+        self.loras: Dict[str, LoRAAdapter] = {}
         for name in self.lora_paths.keys():
-            self.lora_id[name] = len(self.loras)
-            self.loras.append(
-                LoRAAdapter(
-                    name,
-                    self.configs[name],
-                    self.base_hf_config,
-                    self.load_config,
-                    self.lora_backend,
-                )
-            )
-            self.loras[-1].initialize_weights()
+            lora_adapter = LoRAAdapter(
+                name,
+                self.configs[name],
+                self.base_hf_config,
+                self.load_config,
+                self.lora_backend,
+            )
+            lora_adapter.initialize_weights()
+            self.loras[name] = lora_adapter
 
         # misc lora configs
-        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-        self.scaling = self.loras[0].scaling
-        # FIXME remove the restrictions
+        # FIXME remove the restrictions after implementing unified paging
+        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+        self.scaling: float = list(self.loras.values())[0].scaling
         assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-        assert all(x.scaling == self.scaling for x in self.loras)
+        assert all(x.scaling == self.scaling for x in self.loras.values())
 
-        # monkey patch to use the LoRA version
-        self.lora_modules = []
-        for module_name, module in self.get_target_modules():
-            self.lora_modules.append(
-                (module_name, self.set_lora_module(module_name, module))
-            )
+        # Convert original model layers to layers with LoRA
+        self.convert_to_lora_layers()
 
     def init_lora_memory_pool(self):
-        # preallocate lora memory pool
-        self.A_buffer = {}
-        self.B_buffer = {}
-        num_layer = self.base_hf_config.num_hidden_layers
-        for module_A, module_B in self.target_weights:
-            # init A tensor, column_major=True
-            if hasattr(self.base_model, "get_hidden_dim"):
-                hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
-            else:
-                logger.warning(
-                    "WARNING: get_hidden_dim() is not defined, "
-                    "which is used to get the hidden dim for different lora modules"
-                    "Use the default one, but please check if it is correct for your model."
-                )
-                hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
-            c = self.loras[-1].get_stacked_multiply(module_A)
-            if module_A not in self.A_buffer:
-                self.A_buffer[module_A] = [
-                    torch.empty(
-                        (
-                            self.max_loras_per_batch,
-                            self.max_lora_dim * c,
-                            hidden_dim_A,
-                        ),
-                        dtype=self.dtype,
-                        device="cuda",
-                    )
-                    for i in range(num_layer)
-                ]
-            # init B tensor, column_major=True
-            if hasattr(self.base_model, "get_hidden_dim"):
-                _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
-            else:
-                logger.warning(
-                    "WARNING: get_hidden_dim() is not defined, "
-                    "which is used to get the hidden dim for different lora modules"
-                    "Use the default one, but please check if it is correct for your model."
-                )
-                _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
-            c = self.loras[-1].get_stacked_multiply(module_B)
-            if module_B not in self.B_buffer:
-                self.B_buffer[module_B] = [
-                    torch.empty(
-                        (
-                            c,
-                            self.max_loras_per_batch,
-                            hidden_dim_B,
-                            self.max_lora_dim,
-                        ),
-                        dtype=self.dtype,
-                        device="cuda",
-                    )
-                    for i in range(num_layer)
-                ]
-
-    def init_lora_batch(self):
-        self.active_uids = set()  # set of active loras
-        self.buffer_id = {}  # lora uid -> idx in memory pool
-
-    def get_weight_name(self, name, idx):
-        for target_weight_name in self.target_weights:
-            if target_weight_name[idx] in name:
-                return target_weight_name[idx]
-
-    def load_lora(self, uid, buffer_id):
-        num_layer = self.base_hf_config.num_hidden_layers
-        if uid is None:
-            for i in range(num_layer):
-                for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
-            return
-
-        for i in range(num_layer):
-            layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
-            for name, weights in layer_weights.items():
-                if "lora_A" in name:
-                    lora_weight_name = self.get_weight_name(name, 0)
-                    if lora_weight_name:
-                        self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
-                else:
-                    lora_weight_name = self.get_weight_name(name, 1)
-                    if lora_weight_name:
-                        c = self.loras[-1].get_stacked_multiply(lora_weight_name)
-                        if c > 1:
-                            for j in range(c):
-                                self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
-                                    weights[j]
-                                )
-                        else:
-                            self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
-                                weights
-                            )
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+        )
+
+        # Initialize target lora modules in memory pool
+        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        i = 0
-        j = len(self.active_uids)
-        evictable_uids = list(self.active_uids)
-        for uid in cur_uids:
-            if uid not in self.active_uids:
-                if j < self.max_loras_per_batch:
-                    index = j
-                    j += 1
-                else:
-                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                        i += 1
-                    assert i < len(evictable_uids)
-                    self.active_uids.remove(evictable_uids[i])
-                    self.buffer_id.pop(evictable_uids[i])
-                    index = i
-                    i += 1
-                self.load_lora(uid, index)
-                self.active_uids.add(uid)
-                self.buffer_id[uid] = index
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
         # FIXME: Handle lora uid with None more safely
         if cur_uids == set([None]):
             return
@@ -332,9 +140,9 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.buffer_id[lora_path]
+            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
 
-        batch_info = LoraBatchInfo(
+        batch_info = LoRABatchInfo(
             bs=bs,
             seg_lens=seg_lens,
             seg_indptr=seg_indptr,
@@ -346,16 +154,40 @@ class LoRAManager:
         # call set_lora_info for each lora modules
         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
-
             if "qkv_proj" not in module_name:
-                weight_name = self.get_weight_name(module_name, 0)
+                weight_name = get_weight_name(
+                    module_name, self.lora_weight_names, LoRAType.LORA_A
+                )
                 module.set_lora_info(
-                    self.A_buffer[weight_name][layer_id],
-                    self.B_buffer[weight_name][layer_id],
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                 )
             else:
                 module.set_lora_info(
-                    self.A_buffer["qkv_proj"][layer_id],
-                    self.B_buffer["q_proj"][layer_id],
-                    self.B_buffer["kv_proj"][layer_id],
+                    self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
                 )
+
+    def set_lora_module(self, module_name, module):
+        lora_module = get_lora_layer(
+            module, self.max_lora_dim, self.scaling, self.lora_backend
+        )
+        replace_submodule(self.base_model, module_name, lora_module)
+        return lora_module
+
+    def convert_to_lora_layers(self):
+        # Target module names of customized layers defined in python/sglang/srt/layers
+        # e.g., {"qkv_proj", "o_proj"}
+        customized_target_names = get_customized_names_from_hf_names(
+            self.hf_target_names, self.base_model
+        )
+
+        # Monkey patch to use the LoRA version layers
+        self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+        for module_name, module in self.base_model.named_modules():
+            # The module should be converted if it is included in target_names
+            if module_name.split(".")[-1] in customized_target_names:
+                self.lora_modules.append(
+                    (module_name, self.set_lora_module(module_name, module))
+                )
python/sglang/srt/lora/mem_pool.py — new file (0 → 100644), view file @ c45cab1c

from typing import Dict, List, Optional, Set, Tuple

import torch

from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
    LoRAType,
    get_hidden_dim,
    get_stacked_multiply,
    get_weight_name,
)


class LoRAMemoryPool:
    """Class for memory pool management of lora modules"""

    def __init__(
        self,
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        max_lora_dim: int,
        dtype: torch.dtype,
    ):
        self.base_hf_config: AutoConfig = base_hf_config
        self.num_layer: int = base_hf_config.num_hidden_layers
        self.max_loras_per_batch: int = max_loras_per_batch
        self.max_lora_dim: int = max_lora_dim
        self.dtype: torch.dtype = dtype

        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape
        #   (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
        # B_buffer contains num_layer number of column-major tensors with shape
        #   (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
        self.A_buffer: Dict[str, List[torch.Tensor]] = {}
        self.B_buffer: Dict[str, List[torch.Tensor]] = {}

        # Lora uid -> buffer idx in memory pool
        self.uid_to_buffer_id: Dict[Optional[str], int] = {}

        # Buffer idx -> lora uid in memory pool
        # All uids are initialized as empty strings for empty buffer slots
        # Here we don't initialize to None since None is a valid uid
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

    def init_buffers(
        self,
        lora_weight_names: Set[Tuple[str]],
        base_model: torch.nn.Module,
    ):
        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names

        for module_A, module_B in lora_weight_names:
            # Init A tensor, column_major=False
            input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
            c = get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            input_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(self.num_layer)
                ]

            # Init B tensor, column_major=True
            _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
            c = get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            c,  # stacked lora_b modules might need separation
                            self.max_loras_per_batch,
                            output_dim,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(self.num_layer)
                ]

    def prepare_lora_batch(
        self,
        cur_uids: Set[Optional[str]],
        lora_adapters: Dict[str, LoRAAdapter],
    ):
        def get_available_buffer_slot():
            for buffer_id in range(self.max_loras_per_batch):
                # Prioritize empty slots
                if self.buffer_id_to_uid[buffer_id] == "":
                    return buffer_id, ""

            for buffer_id in range(self.max_loras_per_batch):
                # Evict unneeded lora
                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
                    return buffer_id, self.buffer_id_to_uid[buffer_id]

            raise ValueError(
                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
            )

        for uid in cur_uids:
            if uid not in self.uid_to_buffer_id:
                buffer_id, evicted_lora_uid = get_available_buffer_slot()
                if evicted_lora_uid != "":
                    self.uid_to_buffer_id.pop(evicted_lora_uid)
                self.load_lora_weight_to_buffer(
                    uid, buffer_id, lora_adapters.get(uid, None)
                )
                self.uid_to_buffer_id[uid] = buffer_id
                self.buffer_id_to_uid[buffer_id] = uid

    def load_lora_weight_to_buffer(
        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
    ):
        if uid is None:
            for i in range(self.num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        assert lora_adapter is not None
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
                            weights
                        )
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    if lora_weight_name:
                        c = get_stacked_multiply(lora_weight_name)
                        if c > 1:
                            for stacked_id in range(c):
                                self.B_buffer[lora_weight_name][layer_id][stacked_id][
                                    buffer_id
                                ].copy_(weights[stacked_id])
                        else:
                            self.B_buffer[lora_weight_name][layer_id][0][
                                buffer_id
                            ].copy_(weights)

    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
        if lora_type == LoRAType.LORA_A:
            return self.A_buffer[weight_name][layer_id]
        return self.B_buffer[weight_name][layer_id]

    def get_buffer_id(self, lora_uid: str):
        return self.uid_to_buffer_id[lora_uid]
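The slot-selection policy inside prepare_lora_batch above, pulled out as a standalone, runnable sketch so the eviction rule is easy to see (empty slots first, then any slot whose adapter is not needed by the current batch); the function and variable names here are illustrative only.

def pick_slot(buffer_id_to_uid, cur_uids):
    # Prefer a slot that has never been filled (marked with "")
    for i, uid in enumerate(buffer_id_to_uid):
        if uid == "":
            return i, ""
    # Otherwise evict any loaded adapter the current batch does not use
    for i, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:
            return i, uid
    raise ValueError("No available buffer slots; increase max_loras_per_batch.")

slots = ["", "", ""]
print(pick_slot(slots, {"adapter_a"}))   # (0, '') -> first empty slot is used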
python/sglang/srt/lora/triton_ops/__init__.py — view file @ c45cab1c

+from .gate_up_lora_b import gate_up_lora_b_fwd
 from .qkv_lora_b import qkv_lora_b_fwd
 from .sgemm_lora_a import sgemm_lora_a_fwd
 from .sgemm_lora_b import sgemm_lora_b_fwd
 
-__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
+__all__ = [
+    "gate_up_lora_b_fwd",
+    "qkv_lora_b_fwd",
+    "sgemm_lora_a_fwd",
+    "sgemm_lora_b_fwd",
+]
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
0 → 100644
View file @
c45cab1c
import
torch
import
triton
import
triton.language
as
tl
from
sglang.srt.lora.utils
import
LoRABatchInfo
@
triton
.
jit
def
_gate_up_lora_b_kernel
(
# Pointers to matrices
x
,
weights
,
output
,
# Parameters of size
K
,
# K = R
output_dim
,
# Strides
x_stride_0
,
x_stride_1
,
w_stride_0
,
w_stride_1
,
w_stride_2
,
output_stride_0
,
output_stride_1
,
# Information on sequence lengths and weight id
seg_lens
,
seg_indptr
,
weight_indices
,
# Meta parameters
BLOCK_S
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
# For fused output scaling and adding
fuse_scaling_add
,
scaling
,
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, 2 * output_dim, K)
# output: (s, 2 * output_dim)
# output_dim >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# gate_up_id decides which of gate or up (0: gate, 1: up)
batch_id
=
tl
.
program_id
(
axis
=
2
)
gate_up_id
=
tl
.
program_id
(
axis
=
1
)
pid
=
tl
.
program_id
(
axis
=
0
)
seg_len
=
tl
.
load
(
seg_lens
+
batch_id
)
w_index
=
tl
.
load
(
weight_indices
+
batch_id
)
seg_start
=
tl
.
load
(
seg_indptr
+
batch_id
)
n_start
=
gate_up_id
*
output_dim
# offset on output dim
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n
=
tl
.
cdiv
(
output_dim
,
BLOCK_N
)
pid_s
=
pid
//
num_pid_n
pid_n
=
pid
%
num_pid_n
# Create pointers for the first block of x and weights
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset
=
tl
.
arange
(
0
,
BLOCK_S
)
+
pid_s
*
BLOCK_S
n_offset
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
k_offset
=
tl
.
arange
(
0
,
BLOCK_K
)
x_ptrs
=
(
x
+
seg_start
*
x_stride_0
+
(
gate_up_id
*
K
)
*
x_stride_1
)
+
(
s_offset
[:,
None
]
*
x_stride_0
+
k_offset
[
None
,
:]
*
x_stride_1
)
w_ptrs
=
(
weights
+
w_index
*
w_stride_0
+
n_start
*
w_stride_1
)
+
(
k_offset
[:,
None
]
*
w_stride_2
+
n_offset
[
None
,
:]
*
w_stride_1
)
# Iteate to compute the block in output matrix
partial_sum
=
tl
.
zeros
((
BLOCK_S
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_K
)):
x_tile
=
tl
.
load
(
x_ptrs
,
mask
=
(
s_offset
[:,
None
]
<
seg_len
)
and
(
k_offset
[
None
,
:]
<
K
-
k
*
BLOCK_K
),
other
=
0.0
,
)
w_tile
=
tl
.
load
(
w_ptrs
,
mask
=
(
k_offset
[:,
None
]
<
K
-
k
*
BLOCK_K
)
and
(
n_offset
[
None
,
:]
<
output_dim
),
other
=
0.0
,
)
partial_sum
+=
tl
.
dot
(
x_tile
,
w_tile
)
x_ptrs
+=
BLOCK_K
*
x_stride_1
w_ptrs
+=
BLOCK_K
*
w_stride_2
# Store result to output matrix
partial_sum
*=
scaling
partial_sum
=
partial_sum
.
to
(
x
.
dtype
.
element_ty
)
output_ptr
=
(
output
+
seg_start
*
output_stride_0
+
n_start
*
output_stride_1
)
+
(
s_offset
[:,
None
]
*
output_stride_0
+
n_offset
[
None
,
:]
*
output_stride_1
)
output_mask
=
(
s_offset
[:,
None
]
<
seg_len
)
and
(
n_offset
[
None
,
:]
<
output_dim
)
if
fuse_scaling_add
:
partial_sum
+=
tl
.
load
(
output_ptr
,
mask
=
output_mask
)
tl
.
store
(
output_ptr
,
partial_sum
,
mask
=
output_mask
)
def gate_up_lora_b_fwd(
    x: torch.Tensor,
    gate_up_lora_b: torch.Tensor,
    batch_info: LoRABatchInfo,
    output_dim: int,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:
    # x: (s, 2 * r)
    # gate_up_lora_b: (num_lora, 2 * output_dim, r)
    # output: (s, 2 * output_dim)

    # Compute lora_output with shape (s, output_dim) as follows:
    # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
    # lora_output[:, output_dim:]
    #      = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])

    # Get dims
    s = x.shape[0]
    input_dim = x.shape[1]
    r = gate_up_lora_b.shape[-1]
    assert input_dim == 2 * r

    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_OUT = 64

    grid_b = (
        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
        2,  # this dimension decides current block computes on gate or up proj
        batch_info.bs,
    )

    if base_output is None:
        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
        fuse_scaling_add = False
    else:
        output = base_output
        fuse_scaling_add = True

    _gate_up_lora_b_kernel[grid_b](
        x,
        gate_up_lora_b,
        output,
        r,
        output_dim,
        x.stride(0),
        x.stride(1),
        gate_up_lora_b.stride(0),
        gate_up_lora_b.stride(1),
        gate_up_lora_b.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        BLOCK_S,
        BLOCK_OUT,
        BLOCK_R,
        fuse_scaling_add,
        scaling,
    )

    return output
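For readers who want to exercise this entry point outside the engine, here is a minimal smoke-test sketch. It is not part of the commit: the import of gate_up_lora_b_fwd through the triton_ops package, the CUDA device, and the concrete shapes/values are assumptions that simply follow the shape comments above.

# Minimal sketch (not part of the commit): calling gate_up_lora_b_fwd on dummy data.
# Assumes gate_up_lora_b_fwd is exported from sglang.srt.lora.triton_ops and CUDA is available.
import torch

from sglang.srt.lora.triton_ops import gate_up_lora_b_fwd
from sglang.srt.lora.utils import LoRABatchInfo

num_lora, r, output_dim = 2, 16, 128
seg_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")  # two sequences
s = int(seg_lens.sum())

batch_info = LoRABatchInfo(
    bs=2,
    seg_lens=seg_lens,
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.tensor([0, 1], dtype=torch.int32, device="cuda"),
)

x = torch.randn(s, 2 * r, dtype=torch.float16, device="cuda")  # stacked gate/up LoRA-A output
w_b = torch.randn(num_lora, 2 * output_dim, r, dtype=torch.float16, device="cuda")

out = gate_up_lora_b_fwd(x, w_b, batch_info, output_dim, scaling=0.5)
assert out.shape == (s, 2 * output_dim)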
python/sglang/srt/lora/triton_ops/qkv_lora_b.py
View file @
c45cab1c
...
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 @triton.jit
...
...
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
 def qkv_lora_b_fwd(
     x: torch.Tensor,
     qkv_lora_b: torch.Tensor,
-    batch_info: LoraBatchInfo,
+    batch_info: LoRABatchInfo,
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
...
...
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
     # output: (s, output_dim_q + 2 * output_dim_kv)
     # Compute lora_output with shape (s, output_dim) as follows:
-    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b[0])
+    # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
     # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-    #      = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+    #      = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
     # lora_output[:, output_dim_q + output_dim_kv: ]
-    #      = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+    #      = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])
     # Get dims
     s = x.shape[0]
...
...
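To make the slicing described in the comment above concrete, here is a hedged eager-mode reference. It assumes a single adapter whose B weight has already been gathered (so the leading num_lora dimension is dropped) and mirrors only the arithmetic of the comment, not the Triton kernel's segment and offset bookkeeping.

# Illustrative eager-mode reference (not part of the commit) for the slicing above.
import torch

def qkv_lora_b_reference(
    x: torch.Tensor,            # (s, 3 * r): stacked q/k/v LoRA-A outputs
    qkv_lora_b: torch.Tensor,   # (output_dim_q + 2 * output_dim_kv, r), one adapter
    output_dim_q: int,
    output_dim_kv: int,
    scaling: float = 1.0,
) -> torch.Tensor:
    r = qkv_lora_b.shape[-1]
    out = torch.empty(
        x.shape[0], output_dim_q + 2 * output_dim_kv, dtype=x.dtype, device=x.device
    )
    out[:, :output_dim_q] = x[:, :r] @ qkv_lora_b[:output_dim_q, :].T
    out[:, output_dim_q : output_dim_q + output_dim_kv] = (
        x[:, r : 2 * r] @ qkv_lora_b[output_dim_q : output_dim_q + output_dim_kv, :].T
    )
    out[:, output_dim_q + output_dim_kv :] = (
        x[:, 2 * r :] @ qkv_lora_b[output_dim_q + output_dim_kv :, :].T
    )
    return out * scaling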
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
View file @
c45cab1c
...
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 @triton.jit
...
...
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
 ) -> torch.Tensor:
     # x: (s, input_dim)
     # weights: (num_lora, r, input_dim)
...
...
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
View file @
c45cab1c
...
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 @triton.jit
...
...
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
 def sgemm_lora_b_fwd(
     x: torch.Tensor,
     weights: torch.Tensor,
-    batch_info: LoraBatchInfo,
+    batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
     scaling: float = 1.0,
 ) -> torch.Tensor:
...
...
python/sglang/srt/lora/utils.py
0 → 100644
View file @
c45cab1c
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple

import torch

from sglang.srt.hf_transformers_utils import AutoConfig


@dataclass
class LoRABatchInfo:
    # Batch size
    bs: int

    # Lengths of each sequence in shape (bs,)
    seg_lens: torch.Tensor

    # Index pointers of each sequence in shape (bs + 1, )
    seg_indptr: torch.Tensor

    # Maximum sequence length of current batch
    max_len: int

    # The index of lora adapter used by each sequence, in shape (bs,)
    weight_indices: torch.Tensor


class LoRAType(Enum):
    LORA_A = 0
    LORA_B = 1


def get_layer_id(name: str) -> int:
    """
    Extract the integer id of a layer from its name in string.
    """
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


def get_customized_names_from_hf_names(
    hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
    """
    This function takes in a set of huggingface style module names:
         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
    and outputs a set of module names of customized sglang layers:
         e.g., {"qkv_proj", "o_proj"}
    """
    if hasattr(base_model, "get_module_name"):
        return {base_model.get_module_name(name) for name in hf_module_names}
    else:
        """
        Fallback solution of mapping from config module name to module name in model class.
        Please check if it aligns with your base model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return {params_mapping.get(name, name) for name in hf_module_names}


def get_hidden_dim(
    module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
    """
    Given a module_name (might be a stacked name), return the hidden dims of the module's input and output.
    """
    if hasattr(base_model, "get_hidden_dim"):
        return base_model.get_hidden_dim(module_name)
    else:
        """
        WARNING: get_hidden_dim() is not defined,
        which is used to get the hidden dim for different lora modules.
        Use the default one, but please check if it is correct for your model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return config.hidden_size, config.hidden_size
        elif module_name in ["kv_proj"]:
            return config.hidden_size, config.hidden_size // (
                config.num_attention_heads // config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return config.hidden_size, config.intermediate_size
        elif module_name == "down_proj":
            return config.intermediate_size, config.hidden_size
        else:
            raise NotImplementedError()


def get_stacked_name(name: str) -> Tuple[str]:
    """
    Map a target module name to (stacked name for Lora A, stacked name for Lora B).
    """
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_stacked_multiply(module_name: str) -> int:
    """
    Map a lora module name to its magnification at the output dimension.
    """
    stacked_rank = {
        "qkv_proj": 3,
        "kv_proj": 2,
        "gate_up_proj": 2,
    }
    return stacked_rank[module_name] if module_name in stacked_rank else 1


def get_weight_name(
    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
) -> Optional[str]:
    """
    target_name is the name of a given module,
    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above).
    If there is a weight name in lora_weight_names that can match target_name, return this name;
    else return None.
    """
    idx = 0 if lora_type == LoRAType.LORA_A else 1
    for weight_name_pair in lora_weight_names:
        if weight_name_pair[idx] in target_name:
            return weight_name_pair[idx]
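A short illustrative snippet of how these helpers compose is below; the module-name strings are made-up examples, and the expected values in the comments follow directly from the mappings defined above.

# Illustrative usage of the helpers above; module names are made-up examples.
from sglang.srt.lora.utils import (
    LoRAType,
    get_layer_id,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

print(get_layer_id("model.layers.12.self_attn.qkv_proj"))  # 12
print(get_stacked_name("k_proj"))                           # ('qkv_proj', 'kv_proj')
print(get_stacked_multiply("gate_up_proj"))                 # 2

lora_weight_names = {get_stacked_name("q_proj"), get_stacked_name("gate_proj")}
print(get_weight_name("model.layers.0.self_attn.qkv_proj", lora_weight_names, LoRAType.LORA_A))
# 'qkv_proj'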
test/srt/models/test_lora_backend.py
View file @
c45cab1c
...
...
@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
 LORA_SETS = [
     {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
     # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
+    {
+        "base": "meta-llama/Llama-3.1-8B-Instruct",
+        "loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
+        "decode_tolerance": 8e-2,
+    },
 ]
 TORCH_DTYPES = [torch.float16]
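The new "decode_tolerance" key (and the matching "prefill_tolerance" key) is read further down via lora_set.get(...), so a per-adapter override only needs an extra dict entry. A hypothetical example follows; the "my-org/..." identifiers are placeholders and not part of the commit.

# Hypothetical entry (not part of the commit); "my-org/..." names are placeholders.
LORA_SETS.append(
    {
        "base": "my-org/my-base-model",
        "loras": ["my-org/my-lora-adapter"],
        "prefill_tolerance": 1e-1,  # picked up via lora_set.get("prefill_tolerance", default)
        "decode_tolerance": 1e-1,   # picked up via lora_set.get("decode_tolerance", default)
    }
)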
...
...
@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
                 torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
             )
             if hf_logprobs.shape[0] <= 100:
-                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
+                tol = lora_set.get("prefill_tolerance", prefill_tolerance)
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
                     f"prefill logprobs are not all close with model_path={base_path},"
                     f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
                     f"prefill_tolerance={prefill_tolerance}."
...
...
@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
                 "\n",
             )
             if hf_logprobs.shape[0] <= 100:
-                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
+                tol = lora_set.get("decode_tolerance", decode_tolerance)
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
                     f"decode logprobs are not all close with model_path={base_path},"
                     f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
                     f"decode_tolerance={decode_tolerance}."
...
...
@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
             # compare output strings
             srt_output_str = srt_outputs.output_strs[i].strip(" ")
-            hf_output_str = hf_outputs.output_strs[i]
+            hf_output_str = hf_outputs.output_strs[i].strip(" ")
             print(f"srt_output_str={srt_output_str}")
             print(f"hf_output_str={hf_output_str}")
             rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
...
...