Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eefeb164
Unverified
Commit
eefeb164
authored
Apr 27, 2024
by
Austin Veselka
Committed by
GitHub
Apr 27, 2024
Browse files
[Kernel] Full Tensor Parallelism for LoRA Layers (#3524)
Co-authored-by:
Antoni Baum
<
antoni.baum@protonmail.com
>
parent
18d23f64
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
686 additions
and
111 deletions
+686
-111
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+78
-0
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
+1
-0
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
+1
-0
csrc/punica/bgmv/bgmv_impl.cuh
csrc/punica/bgmv/bgmv_impl.cuh
+4
-1
csrc/punica/bgmv/generator.py
csrc/punica/bgmv/generator.py
+1
-0
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+1
-1
tests/lora/test_layers.py
tests/lora/test_layers.py
+23
-6
tests/lora/test_punica.py
tests/lora/test_punica.py
+49
-2
vllm/config.py
vllm/config.py
+1
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+10
-0
vllm/lora/fully_sharded_layers.py
vllm/lora/fully_sharded_layers.py
+262
-0
vllm/lora/layers.py
vllm/lora/layers.py
+146
-97
vllm/lora/models.py
vllm/lora/models.py
+3
-3
vllm/lora/punica.py
vllm/lora/punica.py
+43
-0
vllm/lora/utils.py
vllm/lora/utils.py
+59
-1
No files found.
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
View file @
eefeb164
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
View file @
eefeb164
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_bfloat16
,
float
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_config.h
View file @
eefeb164
...
...
@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Applies `f` once per wide dimension listed below, always paired with the
// same `narrow` dimension: f(in_T, out_T, W_T, <wide>, narrow).
// When `f` is INST_BGMV_ONESIDE this instantiates only the
// wide feat_in -> narrow feat_out kernels.
// Used for the fully sharded column parallel LoRA A, which splits
// the rank dim.
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
...
...
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// Applies `f` to every dimension pair needed for one-sided instantiation:
// each wide dim from FOR_INST_BGMV_NARROW paired with narrow dims 1, 2 and 4,
// plus the explicit pairs (8, 64), (16, 64), (32, 64) and (64, 64).
// NOTE(review): the rationale for the four explicit *-to-64 pairs is not
// stated here -- confirm against the callers before changing them.
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
View file @
eefeb164
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_half
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_half
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
View file @
eefeb164
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
float
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
nv_half
,
float
,
nv_half
)
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
View file @
eefeb164
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
float
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
View file @
eefeb164
...
...
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_half
,
nv_half
)
FOR_INST_BGMV_WIDE_NARROW
(
INST_BGMV_ONESIDE
,
float
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_impl.cuh
View file @
eefeb164
...
...
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr
int
tz
=
4
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
constexpr
(
feat_in
<
feat_out
)
{
if
constexpr
(
feat_in
<
=
feat_out
)
{
static_assert
(
feat_in
%
vec_size
==
0
);
constexpr
int
tx
=
feat_in
/
vec_size
;
...
...
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
// One-sided instantiation: forwards to INST_BGMV for the single
// (feat_in, feat_out) pair given, without the swapped counterpart that
// INST_BGMV_TWOSIDE also emits.
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
csrc/punica/bgmv/generator.py
View file @
eefeb164
...
...
@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
"""
.
lstrip
()
# noqa: E501
for
input_dtype
in
DTYPES
:
...
...
csrc/punica/punica_ops.cc
View file @
eefeb164
...
...
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW
(
CASE
,
_
,
_
,
_
)
FOR_INST_BGMV_WIDE_NARROW
(
CASE_ONESIDE
,
_
,
_
,
_
)
#undef CASE
#undef CASE_ONESIDE
default:
return
false
;
}
return
true
;
}
...
...
tests/lora/test_layers.py
View file @
eefeb164
...
...
@@ -8,6 +8,10 @@ import torch
import
torch.nn.functional
as
F
from
vllm.config
import
LoRAConfig
from
vllm.lora.fully_sharded_layers
import
(
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLora
,
RowParallelLinearWithShardedLoRA
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
...
...
@@ -524,13 +528,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"orientation"
,
[
"row"
,
"column"
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
device
)
->
None
:
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
device
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
fully_sharded_loras
=
fully_shard
,
lora_dtype
=
torch
.
float16
)
def
create_random_linear_parallel_layer
():
...
...
@@ -540,14 +547,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
RowParallelLinearWithLoRA
(
linear
)
lora_linear
=
(
RowParallelLinearWithLoRA
(
linear
)
if
not
fully_shard
else
RowParallelLinearWithShardedLoRA
(
linear
))
else
:
linear
=
ColumnParallelLinear
(
4096
,
4096
,
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
ColumnParallelLinearWithLoRA
(
linear
)
lora_linear
=
(
ColumnParallelLinearWithLoRA
(
linear
)
if
not
fully_shard
else
ColumnParallelLinearWithShardedLoRA
(
linear
))
lora_linear
.
create_lora_weights
(
max_loras
,
lora_config
)
return
linear
,
lora_linear
...
...
@@ -629,13 +639,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"repeats"
,
[
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"fully_shard"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
device
)
->
None
:
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
device
)
->
None
:
torch
.
set_default_device
(
device
)
max_loras
=
8
lora_config
=
LoRAConfig
(
max_loras
=
max_loras
,
max_lora_rank
=
8
,
fully_sharded_loras
=
fully_shard
,
lora_dtype
=
torch
.
float16
)
def
create_column_parallel_packed_layer
():
...
...
@@ -644,7 +657,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
MergedColumnParallelLinearWithLoRA
(
linear
)
lora_linear
=
(
MergedColumnParallelLinearWithLoRA
(
linear
)
if
not
fully_shard
else
MergedColumnParallelLinearWithShardedLoRA
(
linear
))
elif
repeats
==
3
:
linear
=
QKVParallelLinear
(
4096
,
64
,
...
...
@@ -652,7 +667,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
MergedQKVParallelLinearWithLora
(
linear
)
lora_linear
=
(
MergedQKVParallelLinearWithLora
(
linear
)
if
not
fully_shard
else
MergedQKVParallelLinearWithShardedLora
(
linear
))
else
:
linear
=
QKVParallelLinear
(
4096
,
64
,
...
...
tests/lora/test_punica.py
View file @
eefeb164
...
...
@@ -34,11 +34,14 @@ def _lora_ref_impl(
for
i
,
lora_idx
in
zip
(
range
(
bs
),
indicies
.
cpu
().
tolist
()):
xi
=
x
[
i
].
unsqueeze
(
0
).
to
(
torch
.
float32
)
wa
=
wa_T_all
[
lora_idx
,
layer_idx
].
transpose
(
-
1
,
-
2
).
to
(
torch
.
float32
)
wb
=
wb_T_all
[
lora_idx
,
layer_idx
].
transpose
(
-
1
,
-
2
).
to
(
torch
.
float32
)
if
wb_T_all
is
not
None
:
wb
=
wb_T_all
[
lora_idx
,
layer_idx
].
transpose
(
-
1
,
-
2
).
to
(
torch
.
float32
)
tmp
=
xi
@
wa
y_stage_1
[
i
]
=
tmp
.
squeeze
(
0
)
y_final
[
i
]
+=
(
tmp
@
wb
).
squeeze
(
0
)
*
s
y_final
[
i
]
+=
((
tmp
@
wb
).
squeeze
(
0
)
*
s
if
wb_T_all
is
not
None
else
y_stage_1
[
i
])
return
y_final
,
y_stage_1
...
...
@@ -91,12 +94,56 @@ H1 = H2 = [
128000
,
128256
,
]
H2
=
[
64
]
+
H2
R
=
[
1
,
2
,
4
]
SEED
=
[
0xabcdabcd987
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("r", R)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
    """Compare ``punica.bgmv`` with the reference for LoRA-A-only shapes.

    Inputs of width ``h1`` are projected to ``r`` outputs. The reference is
    ``_lora_ref_impl`` called with ``None`` in place of the LoRA B weights.
    """
    torch.manual_seed(seed)
    lora_count, layer_count, batch = 4, 1, 32
    cuda = torch.device("cuda")
    tensor_dtype = getattr(torch, dtype_str)

    # The random draws happen in a fixed order (weights, indices, then x and
    # y per layer) so the seeded stream is consumed identically on every run.
    wa_T_all = torch.randn(lora_count,
                           layer_count,
                           r,
                           h1,
                           dtype=tensor_dtype,
                           device=cuda)
    indices = torch.randint(lora_count, (batch, ),
                            dtype=torch.long,
                            device=cuda)

    for layer_idx in range(layer_count):
        x = torch.randn(batch, h1, dtype=tensor_dtype, device=cuda)
        y = torch.randn(batch, r, dtype=tensor_dtype, device=cuda)

        # Both implementations accumulate into their own copy of y.
        expected = y.clone()
        _lora_ref_impl(expected, x, wa_T_all, None, indices, layer_idx, 1.0)

        actual = y.clone()
        punica.bgmv(actual, x, wa_T_all, indices, layer_idx, 1.0)

        assert_close(expected, actual)
@
pytest
.
mark
.
parametrize
(
"dtype_str"
,
[
"float16"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"h1"
,
H1
)
@
pytest
.
mark
.
parametrize
(
"h2"
,
H2
)
...
...
vllm/config.py
View file @
eefeb164
...
...
@@ -862,6 +862,7 @@ class SpeculativeConfig:
class
LoRAConfig
:
max_lora_rank
:
int
max_loras
:
int
fully_sharded_loras
:
bool
=
False
max_cpu_loras
:
Optional
[
int
]
=
None
lora_dtype
:
Optional
[
torch
.
dtype
]
=
None
lora_extra_vocab_size
:
int
=
256
...
...
vllm/engine/arg_utils.py
View file @
eefeb164
...
...
@@ -52,6 +52,7 @@ class EngineArgs:
enable_lora
:
bool
=
False
max_loras
:
int
=
1
max_lora_rank
:
int
=
16
fully_sharded_loras
:
bool
=
False
lora_extra_vocab_size
:
int
=
256
lora_dtype
=
'auto'
max_cpu_loras
:
Optional
[
int
]
=
None
...
...
@@ -376,6 +377,14 @@ class EngineArgs:
help
=
(
'Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'
))
parser
.
add_argument
(
'--fully-sharded-loras'
,
action
=
'store_true'
,
help
=
(
'By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'
))
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
EngineArgs
.
device
,
...
...
@@ -509,6 +518,7 @@ class EngineArgs:
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
fully_sharded_loras
=
self
.
fully_sharded_loras
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
...
...
vllm/lora/fully_sharded_layers.py
0 → 100644
View file @
eefeb164
# pylint: disable=unused-argument
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.distributed.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
vllm.lora.layers
import
(
ColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
)
from
vllm.lora.punica
import
bgmv
,
dispatch_bgmv_low_level
if
TYPE_CHECKING
:
pass
def
_fully_sharded_can_replace
(
can_replace
):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def
dec
(
*
args
,
**
kwargs
):
return
(
can_replace
(
*
args
,
**
kwargs
)
and
kwargs
[
'lora_config'
].
fully_sharded_loras
)
return
dec
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Variant of ColumnParallelLinearWithLoRA in which LoRA A is sliced too.

    Following S-LoRA, the slicing is done along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this tensor-parallel rank's slice of ``lora_a``.

        The slice is taken along the second dim of ``lora_a``; its width is
        ``self.lora_a_stacked.shape[2]``.
        """
        rank = get_tensor_model_parallel_rank()
        width = self.lora_a_stacked.shape[2]
        begin = rank * width
        return lora_a[:, begin:begin + width]

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add the sharded LoRA contribution.

        Returns the result viewed with the base layer output's original
        shape.
        """
        base_out = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)

        # Work on 2D views: all leading dims are folded into one.
        flat_x = x.view(-1, x.shape[-1])
        out_orig_shape = base_out.shape
        flat_out = base_out.view(-1, base_out.shape[-1])
        active_indices = self.indices[:self.indices_len[0]]

        # float32 scratch buffer receiving this rank's LoRA A result.
        shrunk = torch.zeros((flat_x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=flat_x.device)
        bgmv(shrunk, flat_x, self.lora_a_stacked, active_indices, 0, 1.0)

        # Collect the partial LoRA A results from the tensor-parallel ranks
        # before applying LoRA B.
        shrunk = tensor_model_parallel_all_gather(shrunk)
        bgmv(flat_out, shrunk, self.lora_b_stacked, active_indices, 0, 1.0)

        # flat_out now holds the column partitioned output
        return flat_out.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, gated on fully sharded LoRAs.

        ``decorate=False`` is passed so the parent's own decorator does not
        apply its "not fully sharded" condition.
        """
        # Everything is forwarded by keyword so the decorators can look the
        # values up by name.
        forwarded = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**forwarded)
def _mcp_apply_weights(x, bias, layer):
    """
    LoRA weight application shared by
    MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora.

    The two differ mainly in the per-slice step (shard_size) used for
    lora_b: it can vary between slices for the QKV layer but is constant
    for the merged column parallel layer.
    """
    # expecting 2 for column parallel and 3 for qkv
    slice_count = len(layer.lora_a_stacked)
    base_out = layer.base_layer.linear_method.apply_weights(
        layer.base_layer, x, bias)

    # Work on 2D views: all leading dims are folded into one.
    flat_x = x.view(-1, x.shape[-1])
    out_orig_shape = base_out.shape
    flat_out = base_out.view(-1, base_out.shape[-1])
    active_indices = layer.indices[:layer.indices_len[0]]

    # One float32 scratch buffer per slice for the LoRA A results.
    shrunk = torch.zeros(
        (slice_count, flat_x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=flat_x.device)
    for idx, lora_a in enumerate(layer.lora_a_stacked):
        bgmv(shrunk[idx], flat_x, lora_a, active_indices, 0, 1.0)

    shrunk = tensor_model_parallel_all_gather(shrunk)

    # Each slice's LoRA B result goes into its own column range of the
    # packed output; the running offset advances by that slice's width.
    column_offset = 0
    for idx in range(slice_count):
        lora_b = layer.lora_b_stacked[idx]
        width = lora_b.shape[2]
        dispatch_bgmv_low_level(flat_out, shrunk[idx], lora_b,
                                active_indices, 0, 1.0, column_offset, width)
        column_offset += width

    # now have column partitioned and packed output
    return flat_out.view(*out_orig_shape)
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Variant of MergedColumnParallelLinearWithLoRA in which the LoRA A's
    are sliced too.

    Following S-LoRA, the slicing is done along the rank dim.
    """

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        """Return this rank's slice of each of the two LoRA A tensors.

        Both tensors are cut along their second dim with the same range,
        whose width is ``self.lora_a_stacked[0].shape[2]``.
        """
        width = self.lora_a_stacked[0].shape[2]
        begin = self.tp_rank * width
        columns = slice(begin, begin + width)
        return [lora_a[0][:, columns], lora_a[1][:, columns]]

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Delegate to the shared merged/QKV implementation."""
        return _mcp_apply_weights(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, gated on fully sharded LoRAs.

        ``decorate=False`` is passed so the parent's own decorator does not
        apply its "not fully sharded" condition.
        """
        # Everything is forwarded by keyword so the decorators can look the
        # values up by name.
        forwarded = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**forwarded)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Variant of MergedQKVParallelLinearWithLora in which the LoRA A's
    are sliced too.

    Following S-LoRA, the slicing is done along the rank dim.
    """

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        """Return this rank's slice of each of the three LoRA A tensors.

        Entries of ``lora_a`` that are ``None`` stay ``None``. Each slice is
        taken along the second dim, with a width read from the matching
        entry of ``self.lora_a_stacked``.
        """
        parts = (0, 1, 2)
        widths = [self.lora_a_stacked[i].shape[2] for i in parts]
        begins = [self.tp_rank * widths[i] for i in parts]

        sliced = []
        for i in parts:
            tensor = lora_a[i]
            if tensor is None:
                sliced.append(None)
            else:
                sliced.append(tensor[:, begins[i]:begins[i] + widths[i]])
        return sliced

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Delegate to the shared merged/QKV implementation."""
        return _mcp_apply_weights(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, gated on fully sharded LoRAs.

        ``decorate=False`` is passed so the parent's own decorator does not
        apply its "not fully sharded" condition.
        """
        # Everything is forwarded by keyword so the decorators can look the
        # values up by name.
        forwarded = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**forwarded)
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Variant of RowParallelLinearWithLoRA in which the LoRA B's are
    sliced too.

    Following S-LoRA, the slicing is done along the output dim.
    The result combines the partial sum of the row parallel base layer
    with the column partitioned output of the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this tensor-parallel rank's slice of ``lora_b``.

        The slice is taken along the second dim of ``lora_b``; its width is
        ``self.lora_b_stacked.shape[2]``.
        """
        width = self.lora_b_stacked.shape[2]
        begin = self.tp_rank * width
        return lora_b[:, begin:begin + width]

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Run the base layer, then add the sharded LoRA contribution.

        Returns the result viewed with the base layer output's original
        shape. The returned tensor has not been reduced across ranks yet.
        """
        base_out = self.base_layer.linear_method.apply_weights(
            self.base_layer, x)

        # Work on 2D views: all leading dims are folded into one.
        flat_x = x.view(-1, x.shape[-1])
        out_orig_shape = base_out.shape
        flat_out = base_out.view(-1, base_out.shape[-1])
        active_indices = self.indices[:self.indices_len[0]]

        # float32 scratch buffer receiving this rank's LoRA A result.
        shrunk = torch.zeros((flat_x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=flat_x.device)
        bgmv(shrunk, flat_x, self.lora_a_stacked, active_indices, 0, 1.0)
        shrunk = tensor_model_parallel_all_reduce(shrunk)

        # Following S-LoRA, the all_gather and all_reduce are fused: the
        # column partitioned LoRA output is added to this rank's slice of
        # the output tensor, which is itself a partial sum because the base
        # layer is row parallel. Only a standard all_reduce remains. Note
        # that the output therefore differs from a normal row parallel
        # output and must be reduced before it is used.
        width = self.lora_b_stacked.shape[2]
        begin = self.tp_rank * width
        dispatch_bgmv_low_level(flat_out, shrunk, self.lora_b_stacked,
                                active_indices, 0, 1.0, begin, width)

        return flat_out.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Defer to the parent check, gated on fully sharded LoRAs.

        ``decorate=False`` is passed so the parent's own decorator does not
        apply its "not fully sharded" condition.
        """
        # Everything is forwarded by keyword so the decorators can look the
        # values up by name.
        forwarded = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**forwarded)
vllm/lora/layers.py
View file @
eefeb164
# pylint: disable=unused-argument
import
inspect
import
math
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Typ
e
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tupl
e
import
torch
import
torch.nn
as
nn
...
...
@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
)
from
vllm.distributed.utils
import
divide
from
vllm.lora.punica
import
add_lora
,
add_lora_slice
,
bgmv
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
...
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
VocabParallelEmbedding
)
if
TYPE_CHECKING
:
pass
...
...
@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
raise
ValueError
(
f
"Unsupported base layer:
{
base_layer
}
"
)
def
_not_fully_sharded_can_replace
(
can_replace
):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def
dec
(
*
args
,
**
kwargs
):
decorate
=
kwargs
.
pop
(
'decorate'
)
if
'decorate'
in
kwargs
else
True
condition
=
(
not
kwargs
[
'lora_config'
].
fully_sharded_loras
if
decorate
else
True
)
return
can_replace
(
*
args
,
**
kwargs
)
and
condition
return
dec
def
_apply_lora
(
x
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
...
...
@@ -130,6 +145,14 @@ class LoRAMapping:
class
BaseLayerWithLoRA
(
nn
.
Module
):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Slice lora a if splitting for tensor parallelism.

    Placeholder on the base class: the body is empty, so it returns None
    unless a subclass overrides it.
    """
    ...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Slice lora b if splitting with tensor parallelism.

    Placeholder on the base class: the body is empty, so it returns None
    unless a subclass overrides it.
    """
    ...
def
create_lora_weights
(
self
,
max_loras
:
int
,
...
...
@@ -317,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
class
ColumnParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
"""
def
__init__
(
self
,
base_layer
:
ColumnParallelLinear
)
->
None
:
super
().
__init__
()
...
...
@@ -331,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
self
.
lora_a_stacked
=
torch
.
zeros
(
max_loras
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
...
...
@@ -357,6 +390,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Return ``lora_a`` unchanged; this layer does not slice LoRA A."""
    return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's slice of ``lora_b``.

    The slice is taken along the second dim of ``lora_b`` and is
    ``self.output_dim`` wide.
    """
    rank = get_tensor_model_parallel_rank()
    width = self.output_dim
    return lora_b[:, rank * width:(rank + 1) * width]
def
set_lora
(
self
,
index
:
int
,
...
...
@@ -365,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
output_dim
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
...
...
@@ -426,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return
output
,
output_bias
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
...
@@ -451,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
n_slices
=
2
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
n_slices
and
self
.
base_layer
.
output_sizes
[
0
]
...
...
@@ -459,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size."
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
self
.
lora_a_stacked
=
tuple
(
torch
.
zeros
(
max_loras
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
...
...
@@ -489,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
    """Return ``lora_a`` unchanged; this layer does not slice LoRA A."""
    return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
    """Return this rank's slice of each of the two LoRA B tensors.

    Both tensors are cut along their second dim with the same range, which
    is ``self.output_dim`` wide and selected by ``self.tp_rank``.
    """
    width = self.output_dim
    columns = slice(self.tp_rank * width, (self.tp_rank + 1) * width)
    return [lora_b[part][:, columns] for part in (0, 1)]
def
set_lora
(
self
,
index
:
int
,
...
...
@@ -499,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
output_dim
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[
0
][:,
start_idx
:
end_idx
],
lora_b
[
1
][:,
start_idx
:
end_idx
]
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
lora_a
[
0
]
is
not
None
:
self
.
lora_a_stacked
[
0
][
...
...
@@ -536,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return
output
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
...
@@ -627,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
base_layer
.
head_size
)
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
self
.
base_layer
.
head_size
)
self
.
q_shard_id
=
tp_rank
self
.
kv_shard_id
=
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
self
.
q_shard_id
=
self
.
tp_rank
self
.
kv_shard_id
=
self
.
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
# q, k, v
self
.
lora_a_stacked
=
(
torch
.
zeros
(
max_loras
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
...
...
@@ -649,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch
.
zeros
(
max_loras
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
...
...
@@ -657,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch
.
zeros
(
max_loras
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
...
...
@@ -705,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
    """Hand back the q/k/v LoRA A matrices exactly as received.

    No tensor-parallel slicing happens for LoRA A in this implementation;
    the input list object itself is returned.
    NOTE(review): presumably overridden by the fully sharded variant --
    confirm against fully_sharded_layers.py.
    """
    return lora_a
def slice_lora_b(
        self, lora_b: List[Optional[torch.Tensor]]
) -> List[Optional[torch.Tensor]]:
    """Slice the q, k and v LoRA B matrices down to this rank's columns.

    The q matrix is cut along dim 1 using ``q_proj_shard_size`` and
    ``q_shard_id``; the k and v matrices are cut along dim 1 using
    ``kv_proj_shard_size`` and ``kv_shard_id``.

    Args:
        lora_b: Three-element list ``[q, k, v]`` of LoRA B matrices; any
            element may be ``None`` when that projection has no LoRA.

    Returns:
        A new three-element list ``[q, k, v]`` with each present matrix
        sliced and each absent one left as ``None``.
    """
    # Default every slot to None so a missing projection no longer leaves
    # its local unbound (previously an UnboundLocalError when building the
    # result list).
    lora_b_q = lora_b_k = lora_b_v = None
    if lora_b[0] is not None:
        lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                             self.q_shard_id:self.q_proj_shard_size *
                             (self.q_shard_id + 1)]
    if lora_b[1] is not None:
        lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                             self.kv_shard_id:self.kv_proj_shard_size *
                             (self.kv_shard_id + 1)]
    if lora_b[2] is not None:
        lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                             self.kv_shard_id:self.kv_proj_shard_size *
                             (self.kv_shard_id + 1)]
    return [lora_b_q, lora_b_k, lora_b_v]
def
set_lora
(
self
,
index
:
int
,
...
...
@@ -715,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
)]
lora_b_q
=
lora_b
[
0
]
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b_q
.
shape
[
1
],
:
lora_b_q
.
shape
[
0
]].
copy_
(
lora_b_q
.
T
,
non_blocking
=
True
)
if
lora_b
[
1
]
is
not
None
:
lora_b_k
=
lora_b
[
1
][:,
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
lora_b_k
=
lora_b
[
1
]
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b_k
.
shape
[
1
],
:
lora_b_k
.
shape
[
0
]].
copy_
(
lora_b_k
.
T
,
non_blocking
=
True
)
if
lora_b
[
2
]
is
not
None
:
lora_b_v
=
lora_b
[
2
][:,
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
lora_b_v
=
lora_b
[
2
]
self
.
lora_b_stacked
[
2
][
index
,
0
,
:
lora_b_v
.
shape
[
1
],
:
lora_b_v
.
shape
[
0
]].
copy_
(
lora_b_v
.
T
,
non_blocking
=
True
)
else
:
if
lora_b
[
0
]
is
not
None
:
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b
[
0
].
shape
[
1
],
:
lora_b
[
0
].
shape
[
0
]].
copy_
(
lora_b
[
0
].
T
,
non_blocking
=
True
)
if
lora_b
[
1
]
is
not
None
:
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
lora_b
[
1
].
T
,
non_blocking
=
True
)
if
lora_b
[
2
]
is
not
None
:
self
.
lora_b_stacked
[
2
][
index
,
0
,
:
lora_b
[
2
].
shape
[
1
],
:
lora_b
[
2
].
shape
[
0
]].
copy_
(
lora_b
[
2
].
T
,
non_blocking
=
True
)
if
lora_a
[
0
]
is
not
None
:
self
.
lora_a_stacked
[
0
][
...
...
@@ -777,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
return
output
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
...
@@ -798,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
lora_a_stacked
=
torch
.
zeros
(
(
max_loras
,
...
...
@@ -808,11 +876,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
tp_size
=
get_tensor_model_parallel_world_size
()
lora_b_output_size_per_partition
=
(
self
.
output_size
if
not
lora_config
.
fully_sharded_loras
else
divide
(
self
.
output_size
,
tp_size
))
self
.
lora_b_stacked
=
torch
.
zeros
(
(
max_loras
,
1
,
self
.
output_size
,
lora_b_
output_size
_per_partition
,
lora_config
.
max_lora_rank
,
),
dtype
=
lora_config
.
lora_dtype
,
...
...
@@ -826,6 +899,17 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
    """Return the rows of LoRA A that belong to the current TP rank.

    The kept row range is ``[rank * input_size, (rank + 1) * input_size)``
    along dim 0, where ``rank`` comes from
    ``get_tensor_model_parallel_rank()``.
    NOTE(review): assumes ``self.input_size`` is the per-partition input
    width of the row-parallel base layer -- confirm.
    """
    rank = get_tensor_model_parallel_rank()
    rows_per_rank = self.input_size
    first_row = rank * rows_per_rank
    last_row = (rank + 1) * rows_per_rank
    return lora_a[first_row:last_row, :]
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
    """Return LoRA B unchanged.

    This implementation performs no slicing of LoRA B; the tensor object
    passed in is returned as-is.
    NOTE(review): presumably overridden by the fully sharded variant --
    confirm against fully_sharded_layers.py.
    """
    return lora_b
def
set_lora
(
self
,
index
:
int
,
...
...
@@ -834,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
self
.
reset_lora
(
index
)
if
self
.
base_layer
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
input_size
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:]
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b
=
self
.
slice_lora_b
(
lora_b
)
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
...
...
@@ -915,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
...
@@ -1096,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
# Special handling for the LogitsProcessor.
return
False
# Every concrete LoRA wrapper class visible in this module's globals: any
# class that subclasses BaseLayerWithLoRA, excluding BaseLayerWithLoRA itself.
# Evaluated once, when this statement runs at import time.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    cls
    for cls in globals().values()
    if inspect.isclass(cls) and issubclass(cls, BaseLayerWithLoRA)
    and cls is not BaseLayerWithLoRA
}
def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` in the first registered LoRA class that accepts it.

    Each class in ``_all_lora_classes`` is asked via ``can_replace_layer``
    whether it can stand in for ``layer``. The first one that agrees is
    instantiated around ``layer``, gets its LoRA weights created, and is
    returned. If none agrees, ``layer`` itself is returned unchanged.
    """
    for candidate in _all_lora_classes:
        if not candidate.can_replace_layer(layer, lora_config,
                                           packed_modules_list, model_config):
            continue
        wrapped = candidate(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    return layer
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Build a ``LogitsProcessorWithLoRA`` around ``layer``.

    The wrapper is constructed from ``layer`` plus the embedding dimension
    of ``lm_head`` and the dtype and device of ``lm_head.weight``; its LoRA
    weights are then created before it is returned.
    """
    hidden_size = lm_head.embedding_dim
    wrapped = LogitsProcessorWithLoRA(layer, hidden_size,
                                      lm_head.weight.dtype,
                                      lm_head.weight.device)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped
vllm/lora/models.py
View file @
eefeb164
...
...
@@ -11,10 +11,10 @@ from torch import nn
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
LoRAMapping
,
from_layer
,
from_layer_logits_processor
)
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
LoRAMapping
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.utils
import
parse_fine_tuned_lora_name
,
replace_submodule
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.utils
import
LRUCache
,
is_pin_memory_available
logger
=
init_logger
(
__name__
)
...
...
vllm/lora/punica.py
View file @
eefeb164
...
...
@@ -49,6 +49,49 @@ def bgmv(
punica_kernels
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Variant of `bgmv` that writes into a column slice of y.

    The whole of y is passed in; `y_offset` and `y_slice_size` select the
    columns that are updated.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
        all of the transposed LoRA matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    # The kernel additionally receives the second dimension of x, followed by
    # the slice size and then the offset (note: the reverse of this function's
    # own parameter order for those two).
    x_width = x.size(1)
    punica_kernels.dispatch_bgmv_low_level(y, x, w_t_all, indicies, layer_idx,
                                           scale, x_width, y_slice_size,
                                           y_offset)
def
add_lora
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
...
...
vllm/lora/utils.py
View file @
eefeb164
from
typing
import
Tupl
e
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Typ
e
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.lora.fully_sharded_layers
import
(
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLora
,
RowParallelLinearWithShardedLoRA
)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
QKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
,
VocabParallelEmbeddingWithLoRA
)
# yapf: enable
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
# Explicit registry of the LoRA layer wrapper classes that from_layer()
# probes with can_replace_layer(). It lists both the regular wrappers and the
# "Sharded" variants imported from vllm.lora.fully_sharded_layers.
# NOTE: this is a set, so the order in which classes are probed is not
# defined by this literal.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
    LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Replace ``layer`` with a LoRA-capable wrapper when one applies.

    Every class in ``_all_lora_classes`` is queried through
    ``can_replace_layer``. The first class that accepts the layer is
    instantiated around it, has ``create_lora_weights`` called, and is
    returned. When no class accepts it, the original ``layer`` is returned.
    """
    for candidate in _all_lora_classes:
        # Arguments are passed by keyword so that a decorator wrapping
        # can_replace_layer can pick them out of kwargs by name.
        accepted = candidate.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config)
        if not accepted:
            continue
        wrapped = candidate(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    return layer
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Create the LoRA-aware logits processor for ``layer``.

    A ``LogitsProcessorWithLoRA`` is built from ``layer`` together with
    ``lm_head.embedding_dim`` and the dtype and device of ``lm_head.weight``.
    Its LoRA weights are allocated via ``create_lora_weights`` and the
    wrapper is returned.
    """
    embedding_dim = lm_head.embedding_dim
    lora_processor = LogitsProcessorWithLoRA(layer, embedding_dim,
                                             lm_head.weight.dtype,
                                             lm_head.weight.device)
    lora_processor.create_lora_weights(max_loras, lora_config, model_config)
    return lora_processor
def
replace_submodule
(
model
:
nn
.
Module
,
module_name
:
str
,
new_module
:
nn
.
Module
)
->
nn
.
Module
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment