sglang · Commit c45cab1c (unverified)
Authored Feb 09, 2025 by Baizhou Zhang; committed via GitHub on Feb 10, 2025
Parent: 27c4c9cf

[Fix] Fix accuracy bug and refactor codes for lora (#3413)

Showing 15 changed files with 1137 additions and 631 deletions (+1137 / -631)
python/sglang/srt/lora/backend/__init__.py             +25   -5
python/sglang/srt/lora/backend/base_backend.py         +31   -9
python/sglang/srt/lora/backend/flashinfer_backend.py   +41   -4
python/sglang/srt/lora/backend/triton_backend.py       +34   -4
python/sglang/srt/lora/layers.py                       +293  -0
python/sglang/srt/lora/lora.py                         +101  -326
python/sglang/srt/lora/lora_manager.py                 +101  -269
python/sglang/srt/lora/mem_pool.py                     +174  -0
python/sglang/srt/lora/triton_ops/__init__.py          +7    -1
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py    +170  -0
python/sglang/srt/lora/triton_ops/qkv_lora_b.py        +5    -5
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py      +2    -2
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py      +2    -2
python/sglang/srt/lora/utils.py                        +141  -0
test/srt/models/test_lora_backend.py                   +10   -4
python/sglang/srt/lora/backend/__init__.py

-from .base_backend import BaseLoraBackend
-from .flashinfer_backend import FlashInferLoraBackend
-from .triton_backend import TritonLoraBackend
+from .base_backend import BaseLoRABackend
+from .flashinfer_backend import FlashInferLoRABackend
+from .triton_backend import TritonLoRABackend
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    backend_mapping = {
+        "triton": TritonLoRABackend,
+        "flashinfer": FlashInferLoRABackend,
+    }
+    if name in backend_mapping:
+        return backend_mapping[name]
+
+    raise Exception(
+        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+    )

 __all__ = [
-    "FlashInferLoraBackend",
-    "TritonLoraBackend",
+    "BaseLoRABackend",
+    "FlashInferLoRABackend",
+    "TritonLoRABackend",
+    "get_backend_from_name",
 ]
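For orientation, here is a minimal sketch of how this new factory is consumed later in the commit (LoRAManager looks the backend class up by its name and then instantiates it with that same name); the standalone script around the call is illustrative, not part of the diff.

# Sketch: selecting a LoRA backend by name via the factory added above.
from sglang.srt.lora.backend import get_backend_from_name

backend_type = get_backend_from_name("triton")   # -> TritonLoRABackend class
lora_backend = backend_type("triton")            # instance that runs the LoRA kernels

# Unknown names fail loudly:
# get_backend_from_name("cuda")  # raises: No supported lora backend called cuda. ...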
python/sglang/srt/lora/backend/base_backend.py

@@ -2,7 +2,7 @@ from typing import Tuple, Union
 import torch

-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo


 def get_fuse_output_scaling_add_from_name(name: str) -> bool:
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
     return mapping.get(name, False)


-def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
     return mapping.get(name, False)


-class BaseLoraBackend:
+class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
@@ -32,11 +32,11 @@ class BaseLoraBackend:
         and the operation of scaling and adding will be fused into kernel
     """

-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
-        self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -46,10 +46,11 @@ class BaseLoraBackend:
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+                here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
                usually input_dim is much larger than r
         Returns:
-            result with shape (s, r)
+            result with shape (s, c * r)
         """
         pass
@@ -83,7 +84,7 @@ class BaseLoraBackend:
             qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
             qkv_lora_b: lora_b module for qkv.
                 If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
-                If passed in as a tuple of two tensors containing:
+                If passed in as a tuple of two tensors, it should contain:
                 a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
                 and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
         Returns:
@@ -91,5 +92,26 @@ class BaseLoraBackend:
         """
         pass

-    def set_batch_info(self, batch_info: LoraBatchInfo):
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+        Args:
+            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
+            gate_up_lora_b: lora_b module for qkv.
+                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
+                If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
+        Returns:
+            result with shape (s, 2 * output_dim)
+        """
+        pass
+
+    def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
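As a reference for the shapes documented above, a small self-contained sketch (not the backend kernels themselves) of what the A/B matmul pair computes for a single adapter when no fusion is applied; all sizes below are made up.

# Reference sketch: run_lora_a_sgemm / run_lora_b_sgemm semantics for one adapter (c = 1).
import torch

s, input_dim, r, output_dim, scaling = 8, 64, 16, 64, 0.5
x = torch.randn(s, input_dim)            # (s, input_dim)
lora_a = torch.randn(r, input_dim)       # one adapter's A matrix: (r, input_dim)
lora_b = torch.randn(output_dim, r)      # one adapter's B matrix: (output_dim, r)
base_output = torch.randn(s, output_dim)

lora_a_out = x @ lora_a.T                # (s, r), what run_lora_a_sgemm returns
lora_out = lora_a_out @ lora_b.T         # (s, output_dim), what run_lora_b_sgemm returns

# When fuse_output_scaling_add is False, the caller applies scaling and the residual add:
y = base_output + lora_out * scaling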
python/sglang/srt/lora/backend/flashinfer_backend.py

@@ -2,17 +2,17 @@ from typing import Tuple
 import torch

-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper


-class FlashInferLoraBackend(BaseLoraBackend):
+class FlashInferLoRABackend(BaseLoRABackend):

-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)

         # Set up SGemm Wrapper from flashinfer
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
         **kwargs,
     ) -> torch.Tensor:

+        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
+
         # Shape of lora_a_output: (s, 3 * r)
         lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
         )

         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
+        lora_rank = gate_up_lora_b[0].shape[-1]
+        output_dim = gate_up_lora_b[0].shape[-2]
+
+        # Shape of lora_a_output: (s, 2 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
+
+        lora_output = torch.empty(
+            (x.shape[0], 2 * output_dim),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # Compute lora for gate and up proj respectively
+        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(),
+            weights=gate_up_lora_b[0],
+        )
+
+        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, lora_rank:].contiguous(),
+            weights=gate_up_lora_b[1],
+        )
+
+        return lora_output
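The same gate/up split, written as a plain-PyTorch sketch for one adapter so the slicing in run_gate_up_lora above is easier to follow; the shapes are illustrative.

# Sketch: gate/up LoRA for a single adapter, with lora_b given as a (gate, up) tuple.
import torch

s, input_dim, r, output_dim = 4, 32, 8, 64
x = torch.randn(s, input_dim)
gate_up_lora_a = torch.randn(2 * r, input_dim)                   # stacked A for gate and up
gate_b, up_b = torch.randn(output_dim, r), torch.randn(output_dim, r)

lora_a_output = x @ gate_up_lora_a.T                              # (s, 2 * r)
lora_output = torch.empty(s, 2 * output_dim)
lora_output[:, :output_dim] = lora_a_output[:, :r] @ gate_b.T     # gate half
lora_output[:, output_dim:] = lora_a_output[:, r:] @ up_b.T       # up half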
python/sglang/srt/lora/backend/triton_backend.py

 import torch

-from sglang.srt.lora.backend import BaseLoraBackend
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
+    gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
     sgemm_lora_a_fwd,
     sgemm_lora_b_fwd,
 )
+from sglang.srt.lora.utils import LoRABatchInfo


-class TritonLoraBackend(BaseLoraBackend):
+class TritonLoRABackend(BaseLoRABackend):

-    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         super().__init__(name, batch_info)

     def run_lora_a_sgemm(
...
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
             scaling,
         )
         return lora_output
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: torch.Tensor,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        # x: (s, input_dim)
+        # gate_up_lora_a: (num_lora, 2 * r, input_dim)
+        # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+        assert isinstance(gate_up_lora_b, torch.Tensor)
+        output_dim = gate_up_lora_b.shape[-2] // 2
+
+        # lora_a_output: (s, 2 * r)
+        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+        lora_output = gate_up_lora_b_fwd(
+            lora_a_output,
+            gate_up_lora_b,
+            self.batch_info,
+            output_dim,
+            base_output,
+            scaling,
+        )
+        return lora_output
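For context, a sketch of the per-batch metadata these Triton kernels consume; the field names and meanings follow the LoRABatchInfo fields shown later in this diff, while the concrete values and the plain dict container are illustrative.

# Sketch: the batch metadata passed as batch_info to the sgemm / gate_up kernels.
import torch

seg_lens = torch.tensor([3, 5], dtype=torch.int32)        # per-sequence lengths
seg_indptr = torch.tensor([0, 3, 8], dtype=torch.int32)   # prefix sums of seg_lens
weight_indices = torch.tensor([0, 2], dtype=torch.int64)  # adapter slot used by each sequence

batch_info = dict(
    bs=2,
    seg_lens=seg_lens,
    seg_indptr=seg_indptr,
    max_len=int(seg_lens.max()),
    weight_indices=weight_indices,
)
# Each kernel block reads seg_indptr[batch_id] to locate its rows in x and
# weight_indices[batch_id] to pick which adapter's weights to use.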
python/sglang/srt/lora/layers.py (new file)

import torch
from torch import nn

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend


class BaseLayerWithLoRA(nn.Module):
    def __init__(
        self,
        base_layer: nn.Module,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ):
        super().__init__()
        self.base_layer: nn.Module = base_layer
        self.lora_rank: int = lora_rank
        self.scaling: float = scaling
        self.set_lora: bool = False
        self.lora_backend: BaseLoRABackend = lora_backend

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

    def set_lora_info(self, *args):
        pass


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self,
        base_layer: VocabParallelEmbedding,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self,
        base_layer: ColumnParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ColumnParallelLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_)

        if self.base_layer.gather_output:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
        self,
        base_layer: MergedColumnParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer_gate_up = A_buffer
        if self.lora_backend.fuse_stacked_lora_b:
            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
            self.B_buffer_gate_up = torch.cat(
                (B_buffer[0], B_buffer[1]), dim=-2
            ).contiguous()
        else:
            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        lora_output = self.lora_backend.run_gate_up_lora(
            x,
            self.A_buffer_gate_up,
            self.B_buffer_gate_up,
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def init__(
        self,
        base_layer: QKVParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
        self,
        A_buffer_qkv: torch.Tensor,
        B_buffer_q: torch.Tensor,
        B_buffer_kv: torch.Tensor,
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv

        if self.lora_backend.fuse_stacked_lora_b:
            assert (
                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
            self.B_buffer_qkv = torch.cat(
                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
            ).contiguous()

            # Offsets of q/k/v in output dimension
            self.output_offset = torch.tensor(
                [
                    0,
                    output_dim_q,
                    output_dim_q + output_dim_kv,
                    output_dim_q + 2 * output_dim_kv,
                ],
                dtype=torch.int32,
                device=B_buffer_q.device,
            )
            # For computing number of launched blocks
            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
        else:
            self.B_buffer_qkv = (
                B_buffer_q,
                B_buffer_kv,
            )

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        if self.lora_backend.fuse_stacked_lora_b:
            backend_kwargs["output_offset"] = self.output_offset
            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim

        lora_output = self.lora_backend.run_qkv_lora(
            x,
            self.A_buffer_qkv,
            self.B_buffer_qkv,
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
        self,
        base_layer: RowParallelLinear,
        lora_rank: int,
        scaling: float,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            lora_a_output,
            self.B_buffer[0],
            **backend_kwargs,
        )
        return (
            lora_output
            if self.lora_backend.fuse_output_scaling_add
            else base_output + lora_output * self.scaling
        )

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in RowParallelLinear
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()
        output_parallel = self.base_layer.quant_method.apply(
            self.base_layer, input_parallel
        )

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (
                output_ + self.base_layer.bias
                if self.base_layer.bias is not None
                else output_
            )
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias


def get_lora_layer(
    layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLoRA,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
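A small sketch of the contract behind the repeated fuse_output_scaling_add branch in these layers: a fusing backend returns the already-combined output from its kernel, while a non-fusing backend returns only the raw LoRA delta and the layer applies the scaling and residual add itself. The helper name below is made up for illustration.

# Sketch: equivalence of the fused and unfused paths the layers handle above.
import torch

def combine(base_output: torch.Tensor, lora_output: torch.Tensor,
            scaling: float, fuse_output_scaling_add: bool) -> torch.Tensor:
    if fuse_output_scaling_add:
        # Kernel already produced base_output + lora_delta * scaling.
        return lora_output
    return base_output + lora_output * scaling

base = torch.randn(4, 16)
delta = torch.randn(4, 16)
assert torch.allclose(combine(base, base + delta * 0.5, 0.5, True),
                      combine(base, delta, 0.5, False))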
python/sglang/srt/lora/lora.py

@@ -19,282 +19,25 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

 import re
-from dataclasses import dataclass
+from typing import Dict, List

 import torch
 from torch import nn

-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader

-
-@dataclass
-class LoraBatchInfo:
-    # Batch size
-    bs: int
-
-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
-
-    # Indice pointers of each sequence in shape (bs + 1, )
-    seg_indptr: torch.Tensor
-
-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
-    weight_indices: torch.Tensor
-
-
-class BaseLayerWithLoRA(nn.Module):
-    ...
-
-class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
-    ...
-
-class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
-    ...
-
-class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    ...
-
-class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    ...
-
-class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
-    ...
-
-def get_lora_layer(
-    layer: nn.Module, lora_rank, scaling, lora_backend
-) -> BaseLayerWithLoRA:
-    ...
-
-def get_mapped_params(module_names):
-    ret = set()
-    for module_name in module_names:
-        ret.add(params_mapping(module_name))
-    return list(ret)

 [Bodies of the removed LoRA layer wrappers elided here; they reappear, with type
 annotations and the new stacked-module handling, in python/sglang/srt/lora/layers.py
 shown above.]

 class LoRALayer(nn.Module):
-    def __init__(self, config, base_hf_config):
+    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
         super().__init__()
-        self.config = config
-        self.base_hf_config = base_hf_config
-        self.weights = {}
-        self.weight_gpu = {}
+        self.config: LoRAConfig = config
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weight_gpu: Dict[str, torch.Tensor] = {}

     def load_to_gpu(self):
         for name, weight in self.weights.items():
...
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
 class LoRAAdapter(nn.Module):
-    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
+    def __init__(
+        self,
+        uid: str,
+        config: LoRAConfig,
+        base_hf_config: AutoConfig,
+        load_config: LoadConfig,
+        lora_backend: BaseLoRABackend,
+    ):
         super().__init__()
-        self.uid = uid
-        self.config = config
+        self.uid: str = uid
+        self.config: LoRAConfig = config
         assert self.config.hf_config["peft_type"].lower() == "lora"
-        self.base_hf_config = base_hf_config
-        self.load_config = load_config
-        self.lora_backend = lora_backend
-        self.scaling = self.config.lora_alpha / self.config.r
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.load_config: LoadConfig = load_config
+        self.lora_backend: BaseLoRABackend = lora_backend
+        self.scaling: float = self.config.lora_alpha / self.config.r

-        self.layers = nn.ModuleList(
+        self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
                 for i in range(base_hf_config.num_hidden_layers)
             ]
         )

-        self.weights = {}
-        self.weights_gpu = {}
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weights_gpu: Dict[str, torch.Tensor] = {}

-    def get_stacked_multiply(self, module_name):
-        stacked_rank = {
-            "qkv_proj": 3,
-            "kv_proj": 2,
-            "gate_up_proj": 2,
-        }
-        return stacked_rank[module_name] if module_name in stacked_rank else 1
-
     def load_to_gpu(self):
         for name, weight in self.weights.items():
...
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-            for weight_name in weight_names:
-                if "k_proj" in weight_name:
-                    q_name = weight_name.replace("k_proj", "q_proj")
-                    v_name = weight_name.replace("k_proj", "v_proj")
-                    kv_name = weight_name.replace("k_proj", "kv_proj")
-                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
-                    if "lora_A" in weight_name:
-                        layer.weights[qkv_name] = torch.cat(
-                            (
-                                layer.weights[q_name],
-                                layer.weights[weight_name],
-                                layer.weights[v_name],
-                            ),
-                            0,
-                        )
-                        layer.weights.pop(q_name)
-                        layer.weights.pop(weight_name)
-                        layer.weights.pop(v_name)
-                    else:
-                        layer.weights[kv_name] = torch.stack(
-                            [
-                                layer.weights[weight_name],
-                                layer.weights[v_name],
-                            ],
-                            dim=0,
-                        )
-                        layer.weights.pop(weight_name)
-                        layer.weights.pop(v_name)
-                elif "gate_proj" in weight_name:
-                    up_name = weight_name.replace("gate_proj", "up_proj")
-                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-                    if "lora_A" in weight_name:
-                        layer.weights[gate_up_name] = torch.cat(
-                            (layer.weights[weight_name], layer.weights[up_name]), 0
-                        )
-                    else:
-                        layer.weights[gate_up_name] = torch.stack(
-                            [layer.weights[weight_name], layer.weights[up_name]], dim=0
-                        )
-                    layer.weights.pop(weight_name)
-                    layer.weights.pop(up_name)
+            self.stack_qkv_proj(weight_names, layer.weights)
+            self.stack_gate_up_proj(weight_names, layer.weights)
+
+    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
+        target_module = set()
+        for weight_name in weight_names:
+            if "k_proj" in weight_name:
+                target_module.add("k_proj")
+            if "q_proj" in weight_name:
+                target_module.add("q_proj")
+            if "v_proj" in weight_name:
+                target_module.add("v_proj")
+        if len(target_module) == 0:
+            return
+
+        for weight_name in weight_names:
+            # We assume every lora adaptor should contain lora modules for q_proj
+            if "q_proj" in weight_name:
+                q_name = weight_name
+                k_name = weight_name.replace("q_proj", "k_proj")
+                v_name = weight_name.replace("q_proj", "v_proj")
+                kv_name = weight_name.replace("q_proj", "kv_proj")
+                qkv_name = weight_name.replace("q_proj", "qkv_proj")
+
+                # If k_proj doesn't have lora, initialize it to zero
+                k_proj_weight = (
+                    weights[k_name]
+                    if "k_proj" in target_module
+                    else torch.zeros_like(weights[v_name])
+                )
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = torch.cat(
+                        (
+                            weights[q_name],
+                            k_proj_weight,
+                            weights[v_name],
+                        ),
+                        0,
+                    )
+                    weights.pop(q_name)
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+                else:
+                    weights[kv_name] = torch.stack(
+                        [
+                            k_proj_weight,
+                            weights[v_name],
+                        ],
+                        dim=0,
+                    )
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+
+    def stack_gate_up_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
+        for weight_name in weight_names:
+            if "gate_proj" in weight_name:
+                up_name = weight_name.replace("gate_proj", "up_proj")
+                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = torch.cat(
+                        (weights[weight_name], weights[up_name]), 0
+                    )
+                else:
+                    weights[gate_up_name] = torch.stack(
+                        [weights[weight_name], weights[up_name]], dim=0
+                    )
+                weights.pop(weight_name)
+                weights.pop(up_name)
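A short sketch of the accuracy fix in stack_qkv_proj above: when an adapter ships no k_proj module, a zero tensor shaped like v_proj stands in for it, so the stacked qkv weight keeps the expected (3 * r, hidden) layout. The weight names and sizes below are illustrative.

# Sketch: stacking lora_A for qkv when k_proj has no LoRA weights.
import torch

r, hidden = 8, 32
weights = {
    "self_attn.q_proj.lora_A.weight": torch.randn(r, hidden),
    "self_attn.v_proj.lora_A.weight": torch.randn(r, hidden),
    # no k_proj lora_A for this adapter
}
k_a = weights.get(
    "self_attn.k_proj.lora_A.weight",
    torch.zeros_like(weights["self_attn.v_proj.lora_A.weight"]),
)
qkv_a = torch.cat(
    (weights["self_attn.q_proj.lora_A.weight"], k_a,
     weights["self_attn.v_proj.lora_A.weight"]),
    0,
)
print(qkv_a.shape)  # torch.Size([24, 32]) == (3 * r, hidden)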
python/sglang/srt/lora/lora_manager.py

@@ -16,307 +16,115 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-import re
+from typing import Dict, List, Set, Tuple

 import torch

-from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
-from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.mem_pool import LoRAMemoryPool
+from sglang.srt.lora.utils import (
+    LoRABatchInfo,
+    LoRAType,
+    get_customized_names_from_hf_names,
+    get_layer_id,
+    get_stacked_name,
+    get_weight_name,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available, replace_submodule
+from sglang.srt.utils import replace_submodule

 logger = logging.getLogger(__name__)

-
-def get_module_name(name):
-    # Fallback solution of mapping from config module name to module name in model class.
-    ...
-
-def get_hidden_dim(module_name, config):
-    # Fallback solution of get_hidden_dim for different modules
-    ...
-
-def get_stacked_name(name):
-    # origin name -> (name for A, name for B)
-    params_mapping = {
-        "q_proj": ("qkv_proj", "q_proj"),
-        "k_proj": ("qkv_proj", "kv_proj"),
-        "v_proj": ("qkv_proj", "kv_proj"),
-        "gate_proj": ("gate_up_proj", "gate_up_proj"),
-        "up_proj": ("gate_up_proj", "gate_up_proj"),
-    }
-    return params_mapping.get(name, (name, name))
-
-def get_backend_from_name(name):
-    ...
-
-def get_layer_id(name):
-    match = re.search(r"layers\.(\d+)\.", name)
-    if match is None:
-        return None
-    return int(match.group(1))

 [Bodies of the other removed module-level fallbacks elided; they move to
 python/sglang/srt/lora/utils.py and python/sglang/srt/lora/backend/__init__.py.]

 class LoRAManager:
     def __init__(
         self,
-        base_model,
-        lora_paths,
-        base_hf_config,
-        max_loras_per_batch,
-        load_config,
-        dtype,
-        lora_backend,
+        base_model: torch.nn.Module,
+        lora_paths: Dict[str, str],
+        base_hf_config: AutoConfig,
+        max_loras_per_batch: int,
+        load_config: LoadConfig,
+        dtype: torch.dtype,
+        lora_backend: str = "triton",
     ):
-        self.base_model = base_model
-        self.lora_paths = lora_paths
-        self.base_hf_config = base_hf_config
-        self.max_loras_per_batch = max_loras_per_batch
-        self.load_config = load_config
-        self.dtype = dtype
+        self.base_model: torch.nn.Module = base_model
+        self.lora_paths: Dict[str, str] = lora_paths
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.max_loras_per_batch: int = max_loras_per_batch
+        self.load_config: LoadConfig = load_config
+        self.dtype: torch.dtype = dtype

-        logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+        # LoRA backend for running sgemm kernels
+        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

         self.init_loras()
         self.init_lora_memory_pool()
-        self.init_lora_batch()
-
-    def match_target_modules(self, module_name):
-        for target_module in self.target_modules:
-            if module_name.split(".")[-1] == target_module:
-                return True
-        return False
-
-    def get_target_modules(self):
-        modules = []
-        for module_name, module in self.base_model.named_modules():
-            if self.match_target_modules(module_name):
-                modules.append((module_name, module))
-        return modules
-
-    def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
-        replace_submodule(self.base_model, module_name, lora_module)
-        return lora_module

     def init_loras(self):
-        # get configs and target modules
-        self.configs = {}
-        self.origin_target_modules = set()
+        # Config of each LoRA adapter
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
-            self.origin_target_modules = set(self.origin_target_modules) | set(
+            self.hf_target_names = set(self.hf_target_names) | set(
                 self.configs[name].target_modules
             )

-        if hasattr(self.base_model, "get_module_name"):
-            self.target_modules = {
-                self.base_model.get_module_name(module)
-                for module in self.origin_target_modules
-            }
-        else:
-            logger.warning(
-                "WARNING: get_module_name() is not defined, "
-                "which is used to map config module name to model implementation module name."
-                "Use the default one, but please check if it is correct for your model."
-            )
-            self.target_modules = {
-                get_module_name(module) for module in self.origin_target_modules
-            }
-
-        self.target_weights = set(
-            [get_stacked_name(module) for module in self.origin_target_modules]
-        )
+        # Target lora weight names for lora_a and lora_b modules repectively.
+        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
+        self.lora_weight_names: Set[Tuple[str]] = set(
+            [get_stacked_name(module) for module in self.hf_target_names]
+        )

         # load all weights to cpu
-        self.loras = []
-        self.lora_id = {}
+        self.loras: Dict[str, LoRAAdapter] = {}
         for name in self.lora_paths.keys():
-            self.lora_id[name] = len(self.loras)
-            self.loras.append(
-                LoRAAdapter(
-                    name,
-                    self.configs[name],
-                    self.base_hf_config,
-                    self.load_config,
-                    self.lora_backend,
-                )
-            )
-            self.loras[-1].initialize_weights()
+            lora_adapter = LoRAAdapter(
+                name,
+                self.configs[name],
+                self.base_hf_config,
+                self.load_config,
+                self.lora_backend,
+            )
+            lora_adapter.initialize_weights()
+            self.loras[name] = lora_adapter

         # misc lora configs
-        self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-        self.scaling = self.loras[0].scaling
-        # FIXME remove the restrictions
+        # FIXME remove the restrictions after implementing unified paging
+        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+        self.scaling: float = list(self.loras.values())[0].scaling
         assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
-        assert all(x.scaling == self.scaling for x in self.loras)
+        assert all(x.scaling == self.scaling for x in self.loras.values())

-        # monkey patch to use the LoRA version
-        self.lora_modules = []
-        for module_name, module in self.get_target_modules():
-            self.lora_modules.append(
-                (module_name, self.set_lora_module(module_name, module))
-            )
+        # Convert original model layers to layers with LoRA
+        self.convert_to_lora_layers()

     def init_lora_memory_pool(self):
-        # preallocate lora memory pool
-        self.A_buffer = {}
-        self.B_buffer = {}
-        num_layer = self.base_hf_config.num_hidden_layers
-        for module_A, module_B in self.target_weights:
-            # init A tensor, column_major=True
-            ...
-            # init B tensor, column_major=True
-            ...
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+        )
+
+        # Initialize target lora modules in memory pool
+        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)

-    def init_lora_batch(self):
-        self.active_uids = set()  # set of active loras
-        self.buffer_id = {}  # lora uid -> idx in memory pool
-
-    def get_weight_name(self, name, idx):
-        for target_weight_name in self.target_weights:
-            if target_weight_name[idx] in name:
-                return target_weight_name[idx]
-
-    def load_lora(self, uid, buffer_id):
-        num_layer = self.base_hf_config.num_hidden_layers
-        if uid is None:
-            for i in range(num_layer):
-                for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
-            return
-        ...

 [Removed bodies elided: the inline A/B buffer preallocation and per-layer weight
 copying now live in LoRAMemoryPool (python/sglang/srt/lora/mem_pool.py, below).]

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        i = 0
-        j = len(self.active_uids)
-        evictable_uids = list(self.active_uids)
-        for uid in cur_uids:
-            if uid not in self.active_uids:
-                if j < self.max_loras_per_batch:
-                    index = j
-                    j += 1
-                else:
-                    while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
-                        i += 1
-                    assert i < len(evictable_uids)
-                    self.active_uids.remove(evictable_uids[i])
-                    self.buffer_id.pop(evictable_uids[i])
-                    index = i
-                    i += 1
-                self.load_lora(uid, index)
-                self.active_uids.add(uid)
-                self.buffer_id[uid] = index
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

         # FIXME: Handle lora uid with None more safely
         if cur_uids == set([None]):
             return
...
@@ -332,9 +140,9 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.buffer_id[lora_path]
+            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)

-        batch_info = LoraBatchInfo(
+        batch_info = LoRABatchInfo(
             bs=bs,
             seg_lens=seg_lens,
             seg_indptr=seg_indptr,
...
@@ -346,16 +154,40 @@ class LoRAManager:
         # call set_lora_info for each lora modules
         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)

             if "qkv_proj" not in module_name:
-                weight_name = self.get_weight_name(module_name, 0)
+                weight_name = get_weight_name(
+                    module_name, self.lora_weight_names, LoRAType.LORA_A
+                )
                 module.set_lora_info(
-                    self.A_buffer[weight_name][layer_id],
-                    self.B_buffer[weight_name][layer_id],
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                 )
             else:
                 module.set_lora_info(
-                    self.A_buffer["qkv_proj"][layer_id],
-                    self.B_buffer["q_proj"][layer_id],
-                    self.B_buffer["kv_proj"][layer_id],
+                    self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
+                    self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
                 )
+
+    def set_lora_module(self, module_name, module):
+        lora_module = get_lora_layer(
+            module, self.max_lora_dim, self.scaling, self.lora_backend
+        )
+        replace_submodule(self.base_model, module_name, lora_module)
+        return lora_module
+
+    def convert_to_lora_layers(self):
+        # Target module names of customized layers defined in python/sglang/srt/layers
+        # e.g., {"qkv_proj", "o_proj"}
+        customized_target_names = get_customized_names_from_hf_names(
+            self.hf_target_names, self.base_model
+        )
+
+        # Monkey patch to use the LoRA version layers
+        self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+        for module_name, module in self.base_model.named_modules():
+            # The module should be converted if it is included in target_names
+            if module_name.split(".")[-1] in customized_target_names:
+                self.lora_modules.append(
+                    (module_name, self.set_lora_module(module_name, module))
+                )
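For reference, a sketch of the name mapping that init_loras relies on, mirroring the get_stacked_name helper that is removed here and re-homed in lora/utils.py; the example target set is illustrative.

# Sketch: mapping HF lora target names to (lora_A buffer name, lora_B buffer name) pairs.
params_mapping = {
    "q_proj": ("qkv_proj", "q_proj"),
    "k_proj": ("qkv_proj", "kv_proj"),
    "v_proj": ("qkv_proj", "kv_proj"),
    "gate_proj": ("gate_up_proj", "gate_up_proj"),
    "up_proj": ("gate_up_proj", "gate_up_proj"),
}
hf_target_names = {"q_proj", "k_proj", "v_proj", "o_proj"}
lora_weight_names = {params_mapping.get(n, (n, n)) for n in hf_target_names}
print(lora_weight_names)
# -> {('qkv_proj', 'q_proj'), ('qkv_proj', 'kv_proj'), ('o_proj', 'o_proj')} (set order may vary)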
python/sglang/srt/lora/mem_pool.py
0 → 100644
View file @
c45cab1c
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
from
sglang.srt.hf_transformers_utils
import
AutoConfig
from
sglang.srt.lora.lora
import
LoRAAdapter
from
sglang.srt.lora.utils
import
(
LoRAType
,
get_hidden_dim
,
get_stacked_multiply
,
get_weight_name
,
)
class
LoRAMemoryPool
:
"""Class for memory pool management of lora modules"""
def
__init__
(
self
,
base_hf_config
:
AutoConfig
,
max_loras_per_batch
:
int
,
max_lora_dim
:
int
,
dtype
:
torch
.
dtype
,
):
self
.
base_hf_config
:
AutoConfig
=
base_hf_config
self
.
num_layer
:
int
=
base_hf_config
.
num_hidden_layers
self
.
max_loras_per_batch
:
int
=
max_loras_per_batch
self
.
max_lora_dim
:
int
=
max_lora_dim
self
.
dtype
:
torch
.
dtype
=
dtype
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
# B_buffer contains num_layer number of column-major tensors with shape
# (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
self
.
A_buffer
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
{}
self
.
B_buffer
:
Dict
[
str
,
List
[
torch.Tensor]] = {}

        # Lora uid -> buffer idx in memory pool
        self.uid_to_buffer_id: Dict[Optional[str], int] = {}

        # Buffer idx -> lora uid in memory pool
        # All uids are initialized as empty strings for empty buffer slots
        # Here we don't initialize to None since None is a valid uid
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

    def init_buffers(
        self,
        lora_weight_names: Set[Tuple[str]],
        base_model: torch.nn.Module,
    ):
        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names

        for module_A, module_B in lora_weight_names:
            # Init A tensor, column_major=False
            input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
            c = get_stacked_multiply(module_A)
            if module_A not in self.A_buffer:
                self.A_buffer[module_A] = [
                    torch.empty(
                        (
                            self.max_loras_per_batch,
                            self.max_lora_dim * c,
                            input_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(self.num_layer)
                ]

            # Init B tensor, column_major=True
            _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
            c = get_stacked_multiply(module_B)
            if module_B not in self.B_buffer:
                self.B_buffer[module_B] = [
                    torch.empty(
                        (
                            c,  # stacked lora_b modules might need separation
                            self.max_loras_per_batch,
                            output_dim,
                            self.max_lora_dim,
                        ),
                        dtype=self.dtype,
                        device="cuda",
                    )
                    for i in range(self.num_layer)
                ]

    def prepare_lora_batch(
        self,
        cur_uids: Set[Optional[str]],
        lora_adapters: Dict[str, LoRAAdapter],
    ):
        def get_available_buffer_slot():
            for buffer_id in range(self.max_loras_per_batch):
                # Prioritize empty slots
                if self.buffer_id_to_uid[buffer_id] == "":
                    return buffer_id, ""

            for buffer_id in range(self.max_loras_per_batch):
                # Evict unneeded lora
                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
                    return buffer_id, self.buffer_id_to_uid[buffer_id]

            raise ValueError(
                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
            )

        for uid in cur_uids:
            if uid not in self.uid_to_buffer_id:
                buffer_id, evicted_lora_uid = get_available_buffer_slot()
                if evicted_lora_uid != "":
                    self.uid_to_buffer_id.pop(evicted_lora_uid)
                self.load_lora_weight_to_buffer(
                    uid, buffer_id, lora_adapters.get(uid, None)
                )
                self.uid_to_buffer_id[uid] = buffer_id
                self.buffer_id_to_uid[buffer_id] = uid

    def load_lora_weight_to_buffer(
        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
    ):
        if uid is None:
            for i in range(self.num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] *= 0
            return

        assert lora_adapter is not None
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    if lora_weight_name:
                        self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
                            weights
                        )
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    if lora_weight_name:
                        c = get_stacked_multiply(lora_weight_name)
                        if c > 1:
                            for stacked_id in range(c):
                                self.B_buffer[lora_weight_name][layer_id][stacked_id][
                                    buffer_id
                                ].copy_(weights[stacked_id])
                        else:
                            self.B_buffer[lora_weight_name][layer_id][0][
                                buffer_id
                            ].copy_(weights)

    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
        if lora_type == LoRAType.LORA_A:
            return self.A_buffer[weight_name][layer_id]

        return self.B_buffer[weight_name][layer_id]

    def get_buffer_id(self, lora_uid: str):
        return self.uid_to_buffer_id[lora_uid]
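For reference, the slot-assignment policy used by prepare_lora_batch prefers empty buffer slots and only then evicts an adapter that the current batch does not need. A minimal standalone sketch of just that policy follows; the helper name and the pool states are hypothetical and not part of the commit:

# Standalone sketch of the buffer-slot policy (hypothetical names, illustrative only).
from typing import List, Optional, Set, Tuple


def pick_slot(
    buffer_id_to_uid: List[Optional[str]], cur_uids: Set[Optional[str]]
) -> Tuple[int, str]:
    # Prioritize empty slots (marked with ""), mirroring get_available_buffer_slot.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid == "":
            return buffer_id, ""
    # Otherwise evict a slot whose adapter is not needed by the current batch.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:
            return buffer_id, uid
    raise ValueError("No available buffer slots found.")


print(pick_slot(["lora-a", "", "lora-b"], {"lora-a", "lora-c"}))        # (1, '')
print(pick_slot(["lora-a", "lora-b", "lora-c"], {"lora-a", "lora-c"}))  # (1, 'lora-b')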
python/sglang/srt/lora/triton_ops/__init__.py
View file @
c45cab1c
+from .gate_up_lora_b import gate_up_lora_b_fwd
 from .qkv_lora_b import qkv_lora_b_fwd
 from .sgemm_lora_a import sgemm_lora_a_fwd
 from .sgemm_lora_b import sgemm_lora_b_fwd
 
-__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
+__all__ = [
+    "gate_up_lora_b_fwd",
+    "qkv_lora_b_fwd",
+    "sgemm_lora_a_fwd",
+    "sgemm_lora_b_fwd",
+]
python/sglang/srt/lora/triton_ops/gate_up_lora_b.py
0 → 100644
View file @
c45cab1c
import torch
import triton
import triton.language as tl

from sglang.srt.lora.utils import LoRABatchInfo


@triton.jit
def _gate_up_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Parameters of size
    K,  # K = R
    output_dim,
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    # This kernel packs 2 sgemms (gate/up) into a single kernel.

    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
    # weights: (num_lora, 2 * output_dim, K)
    # output: (s, 2 * output_dim)
    # output_dim >> K

    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len.
    # gate_up_id decides which of gate or up (0: gate, 1: up)
    batch_id = tl.program_id(axis=2)
    gate_up_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)
    n_start = gate_up_id * output_dim  # offset on output dim

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(output_dim, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)

    x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        x_tile = tl.load(
            x_ptrs,
            mask=(s_offset[:, None] < seg_len)
            and (k_offset[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < K - k * BLOCK_K)
            and (n_offset[None, :] < output_dim),
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix
    partial_sum *= scaling
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
    if fuse_scaling_add:
        partial_sum += tl.load(output_ptr, mask=output_mask)

    tl.store(output_ptr, partial_sum, mask=output_mask)


def gate_up_lora_b_fwd(
    x: torch.Tensor,
    gate_up_lora_b: torch.Tensor,
    batch_info: LoRABatchInfo,
    output_dim: int,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:

    # x: (s, 2 * r)
    # gate_up_lora_b: (num_lora, 2 * output_dim, r)
    # output: (s, 2 * output_dim)

    # Compute lora_output with shape (s, output_dim) as follows:
    # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
    # lora_output[:, output_dim:]
    #      = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])

    # Get dims
    s = x.shape[0]
    input_dim = x.shape[1]
    r = gate_up_lora_b.shape[-1]
    assert input_dim == 2 * r

    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_OUT = 64

    grid_b = (
        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
        2,  # this dimension decides current block computes on gate or up proj
        batch_info.bs,
    )

    if base_output is None:
        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
        fuse_scaling_add = False
    else:
        output = base_output
        fuse_scaling_add = True

    _gate_up_lora_b_kernel[grid_b](
        x,
        gate_up_lora_b,
        output,
        r,
        output_dim,
        x.stride(0),
        x.stride(1),
        gate_up_lora_b.stride(0),
        gate_up_lora_b.stride(1),
        gate_up_lora_b.stride(2),
        output.stride(0),
        output.stride(1),
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        BLOCK_S,
        BLOCK_OUT,
        BLOCK_R,
        fuse_scaling_add,
        scaling,
    )

    return output
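A minimal usage sketch of gate_up_lora_b_fwd follows. It assumes a CUDA device with triton available; all shapes, values, and the tolerance are illustrative assumptions, not taken from the commit:

# Hedged usage sketch (requires a CUDA GPU with triton; shapes/values are assumptions).
import torch

from sglang.srt.lora.triton_ops import gate_up_lora_b_fwd
from sglang.srt.lora.utils import LoRABatchInfo

r, output_dim, num_lora = 16, 128, 2
batch_info = LoRABatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.tensor([0, 1], dtype=torch.int32, device="cuda"),
)

x = torch.randn(8, 2 * r, dtype=torch.float16, device="cuda")          # (s, 2 * r)
w = torch.randn(num_lora, 2 * output_dim, r, dtype=torch.float16, device="cuda")
out = gate_up_lora_b_fwd(x, w, batch_info, output_dim, scaling=0.5)    # (s, 2 * output_dim)

# Check the gate half of the first sequence (rows 0..2, adapter 0) against plain matmul.
ref_gate = 0.5 * (x[:3, :r] @ w[0, :output_dim, :].T)
print(torch.allclose(out[:3, :output_dim], ref_gate, atol=2e-2))       # expected: True (fp16 rounding)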
python/sglang/srt/lora/triton_ops/qkv_lora_b.py
View file @
c45cab1c
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
...
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
 def qkv_lora_b_fwd(
     x: torch.Tensor,
     qkv_lora_b: torch.Tensor,
-    batch_info: LoraBatchInfo,
+    batch_info: LoRABatchInfo,
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
...
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
     # output: (s, output_dim_q + 2 * output_dim_kv)
 
     # Compute lora_output with shape (s, output_dim) as follows:
-    # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
+    # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
     # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-    #      = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+    #      = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
     # lora_output[:, output_dim_q + output_dim_kv: ]
-    #      = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+    #      = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])
 
     # Get dims
     s = x.shape[0]
...
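The slicing described in the corrected comments can be restated as a plain-PyTorch reference for a single sequence; the toy dimensions and adapter index below are assumptions for illustration only:

# Plain-PyTorch restatement of the comment above for one sequence using adapter `idx`
# (toy dimensions; illustrative only).
import torch

r, dq, dkv, idx = 16, 64, 16, 0
x = torch.randn(4, 3 * r)                      # stacked q/k/v lora_a outputs
qkv_lora_b = torch.randn(2, dq + 2 * dkv, r)   # (num_lora, output_dim_q + 2 * output_dim_kv, r)

lora_output = torch.cat(
    [
        x[:, :r] @ qkv_lora_b[idx, :dq, :].T,                   # q slice
        x[:, r : 2 * r] @ qkv_lora_b[idx, dq : dq + dkv, :].T,  # k slice
        x[:, 2 * r :] @ qkv_lora_b[idx, dq + dkv :, :].T,       # v slice
    ],
    dim=-1,
)
print(lora_output.shape)  # torch.Size([4, 96])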
python/sglang/srt/lora/triton_ops/sgemm_lora_a.py
View file @
c45cab1c
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
...
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
 ) -> torch.Tensor:
     # x: (s, input_dim)
     # weights: (num_lora, r, input_dim)
...
python/sglang/srt/lora/triton_ops/sgemm_lora_b.py
View file @
c45cab1c
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
...
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
 def sgemm_lora_b_fwd(
     x: torch.Tensor,
     weights: torch.Tensor,
-    batch_info: LoraBatchInfo,
+    batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
     scaling: float = 1.0,
 ) -> torch.Tensor:
...
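For intuition, the A/B kernel pair composes into the usual LoRA delta added onto the base output. A plain-PyTorch restatement for one sequence follows; the B-weight layout and all shapes here are illustrative assumptions (only the A layout is stated in the diff above):

# Plain-PyTorch restatement for one sequence using adapter `idx`; the B layout
# (num_lora, output_dim, r) and all shapes are assumptions for illustration.
import torch

input_dim, output_dim, r, idx, scaling = 32, 48, 8, 0, 0.5
x = torch.randn(4, input_dim)
A = torch.randn(2, r, input_dim)          # matches "# weights: (num_lora, r, input_dim)"
B = torch.randn(2, output_dim, r)
base_output = torch.randn(4, output_dim)

lora_a_out = x @ A[idx].T                              # what sgemm_lora_a_fwd computes per segment
out = base_output + scaling * (lora_a_out @ B[idx].T)  # sgemm_lora_b_fwd with fused scaling/add
print(out.shape)  # torch.Size([4, 48])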
python/sglang/srt/lora/utils.py
0 → 100644
View file @
c45cab1c
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple

import torch

from sglang.srt.hf_transformers_utils import AutoConfig


@dataclass
class LoRABatchInfo:
    # Batch size
    bs: int

    # Lengths of each sequence in shape (bs,)
    seg_lens: torch.Tensor

    # Index pointers of each sequence in shape (bs + 1, )
    seg_indptr: torch.Tensor

    # Maximum sequence length of current batch
    max_len: int

    # The index of lora adapter used by each sequence, in shape (bs,)
    weight_indices: torch.Tensor


class LoRAType(Enum):
    LORA_A = 0
    LORA_B = 1


def get_layer_id(name: str) -> int:
    """
    Extract integer id of layer from its name in string.
    """
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return None
    return int(match.group(1))


def get_customized_names_from_hf_names(
    hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
    """
    This function takes in a set of huggingface style module names:
         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
    and outputs a set of module names of customized sglang layers:
         e.g., {"qkv_proj", "o_proj"}
    """
    if hasattr(base_model, "get_module_name"):
        return {base_model.get_module_name(name) for name in hf_module_names}
    else:
        """
        Fallback solution of mapping from config module name to module name in model class.
        Please check if it aligns with your base model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return {params_mapping.get(name, name) for name in hf_module_names}


def get_hidden_dim(
    module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
    """
    Given a module_name (might be a stacked name), return the hidden dims of the module's input and output.
    """
    if hasattr(base_model, "get_hidden_dim"):
        return base_model.get_hidden_dim(module_name)
    else:
        """
        WARNING: get_hidden_dim() is not defined,
        which is used to get the hidden dim for different lora modules
        Use the default one, but please check if it is correct for your model.
        Please implement the function in the model class if it is not.
        You can reference this function in llama.py.
        """
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return config.hidden_size, config.hidden_size
        elif module_name in ["kv_proj"]:
            return config.hidden_size, config.hidden_size // (
                config.num_attention_heads // config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return config.hidden_size, config.intermediate_size
        elif module_name == "down_proj":
            return config.intermediate_size, config.hidden_size
        else:
            raise NotImplementedError()


def get_stacked_name(name: str) -> Tuple[str]:
    """
    Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
    """
    params_mapping = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return params_mapping.get(name, (name, name))


def get_stacked_multiply(module_name: str) -> int:
    """
    Mapping a lora module name to its magnification at output dimension
    """
    stacked_rank = {
        "qkv_proj": 3,
        "kv_proj": 2,
        "gate_up_proj": 2,
    }
    return stacked_rank[module_name] if module_name in stacked_rank else 1


def get_weight_name(
    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
) -> Optional[str]:
    """
    target_name is name of a given module,
    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
    If there is a weight name in lora_weight_names that can match target_name, return this name
    Else return None
    """
    idx = 0 if lora_type == LoRAType.LORA_A else 1
    for weight_name_pair in lora_weight_names:
        if weight_name_pair[idx] in target_name:
            return weight_name_pair[idx]
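A quick, self-contained illustration of the helpers above (assumes sglang is installed; the module names in the example strings are made up):

# Illustration of the lora utils helpers (hypothetical module-name strings).
from sglang.srt.lora.utils import (
    LoRAType,
    get_layer_id,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

print(get_layer_id("model.layers.12.self_attn.q_proj"))  # 12
print(get_stacked_name("k_proj"))                         # ('qkv_proj', 'kv_proj')
print(get_stacked_multiply("qkv_proj"))                   # 3

lora_weight_names = {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
print(get_weight_name("layers.0.self_attn.qkv_proj.lora_A.weight", lora_weight_names, LoRAType.LORA_A))  # qkv_proj
print(get_weight_name("layers.0.self_attn.kv_proj.lora_B.weight", lora_weight_names, LoRAType.LORA_B))   # kv_proj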
test/srt/models/test_lora_backend.py
View file @
c45cab1c
...
@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
 LORA_SETS = [
     {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
     # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
+    {
+        "base": "meta-llama/Llama-3.1-8B-Instruct",
+        "loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
+        "decode_tolerance": 8e-2,
+    },
 ]
 TORCH_DTYPES = [torch.float16]
...
@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
                 torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
             )
             if hf_logprobs.shape[0] <= 100:
-                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
+                tol = lora_set.get("prefill_tolerance", prefill_tolerance)
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
                     f"prefill logprobs are not all close with model_path={base_path},"
                     f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
                     f"prefill_tolerance={prefill_tolerance}."
...
@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
                 "\n",
             )
             if hf_logprobs.shape[0] <= 100:
-                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
+                tol = lora_set.get("decode_tolerance", decode_tolerance)
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
                     f"decode logprobs are not all close with model_path={base_path},"
                     f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
                     f"decode_tolerance={decode_tolerance}."
...
@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
         # compare output strings
         srt_output_str = srt_outputs.output_strs[i].strip(" ")
-        hf_output_str = hf_outputs.output_strs[i]
+        hf_output_str = hf_outputs.output_strs[i].strip(" ")
         print(f"srt_output_str={srt_output_str}")
         print(f"hf_output_str={hf_output_str}")
         rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
...
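As a usage note, a LORA_SETS entry can now carry per-set tolerance overrides, which the assertions above read via dict.get with the global tolerance as the fallback. A hypothetical entry (values are illustrative, not from the diff):

# Hypothetical LORA_SETS entry with per-set overrides (illustrative values only).
LORA_SETS.append(
    {
        "base": "meta-llama/Llama-2-7b-hf",
        "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"],
        "prefill_tolerance": 1e-1,  # read via lora_set.get("prefill_tolerance", ...)
        "decode_tolerance": 8e-2,   # read via lora_set.get("decode_tolerance", ...)
    }
)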