Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
273690a5
Unverified
Commit
273690a5
authored
Sep 23, 2025
by
Jee Jee Li
Committed by
GitHub
Sep 23, 2025
Browse files
[Core] Optimize LoRA weight loading (#25403)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
231c2c63
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
83 additions
and
83 deletions
+83
-83
tests/lora/test_layers.py
tests/lora/test_layers.py
+14
-12
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+6
-6
tests/lora/utils.py
tests/lora/utils.py
+4
-4
vllm/lora/layers/base_linear.py
vllm/lora/layers/base_linear.py
+5
-5
vllm/lora/layers/column_parallel_linear.py
vllm/lora/layers/column_parallel_linear.py
+33
-34
vllm/lora/layers/logits_processor.py
vllm/lora/layers/logits_processor.py
+4
-4
vllm/lora/layers/row_parallel_linear.py
vllm/lora/layers/row_parallel_linear.py
+2
-2
vllm/lora/layers/vocal_parallel_embedding.py
vllm/lora/layers/vocal_parallel_embedding.py
+6
-4
vllm/lora/lora_weights.py
vllm/lora/lora_weights.py
+2
-2
vllm/lora/models.py
vllm/lora/models.py
+7
-10
No files found.
tests/lora/test_layers.py
View file @
273690a5
...
@@ -164,8 +164,8 @@ def populate_loras(
...
@@ -164,8 +164,8 @@ def populate_loras(
weight
=
layer_weights
,
weight
=
layer_weights
,
generate_embeddings_tensor
=
generate_embeddings_tensor
,
generate_embeddings_tensor
=
generate_embeddings_tensor
,
)
)
sublora
.
lora_b
=
sublora
.
lora_b
[
:,
(
sublora_len
*
sublora
.
lora_b
=
sublora
.
lora_b
[(
sublora_len
*
i
):(
sublora_len
*
(
i
+
1
))]
i
):(
sublora_len
*
(
i
+
1
))
,
:
]
sublora
.
optimize
()
sublora
.
optimize
()
subloras
.
append
(
sublora
)
subloras
.
append
(
sublora
)
...
@@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
result
=
embedding
(
input_
)
result
=
embedding
(
input_
)
after_a
=
F
.
embedding
(
after_a
=
F
.
embedding
(
input_
,
input_
,
lora
.
lora_a
,
lora
.
lora_a
.
T
,
)
)
result
+=
(
after_a
@
lora
.
lora_b
)
result
+=
(
after_a
@
lora
.
lora_b
.
T
)
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
...
@@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
result
=
expanded_embedding
(
input_
)
result
=
expanded_embedding
(
input_
)
after_a
=
F
.
embedding
(
after_a
=
F
.
embedding
(
original_input_
,
original_input_
,
lora
.
lora_a
,
lora
.
lora_a
.
T
,
)
)
result
+=
(
after_a
@
lora
.
lora_b
)
result
+=
(
after_a
@
lora
.
lora_b
.
T
)
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
...
@@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
lm_head
=
linear
,
lm_head
=
linear
,
embedding_bias
=
None
)
embedding_bias
=
None
)
result
[:,
vocab_size
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
[:,
vocab_size
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
result
+=
input_
@
lora
.
lora_a
.
T
@
lora
.
lora_b
.
T
*
lora
.
scaling
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
logits_processor
.
org_vocab_size
=
vocab_size
logits_processor
.
org_vocab_size
=
vocab_size
...
@@ -692,9 +692,10 @@ def test_linear_replicated(
...
@@ -692,9 +692,10 @@ def test_linear_replicated(
expected_results
:
list
[
torch
.
Tensor
]
=
[]
expected_results
:
list
[
torch
.
Tensor
]
=
[]
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
lora
=
lora_dict
[
lora_id
]
lora
=
lora_dict
[
lora_id
]
result
=
linear
(
input_
)[
0
]
result
=
linear
(
input_
)[
0
]
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
result
+=
input_
@
lora
.
lora_a
.
T
@
lora
.
lora_b
.
T
*
lora
.
scaling
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
...
@@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
lora
=
lora_dict
[
lora_id
]
lora
=
lora_dict
[
lora_id
]
result
=
linear
(
input_
)[
0
]
result
=
linear
(
input_
)[
0
]
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
result
+=
input_
@
lora
.
lora_a
.
T
@
lora
.
lora_b
.
T
*
lora
.
scaling
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
...
@@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
result
=
linear
(
input_
)[
0
]
result
=
linear
(
input_
)[
0
]
subloras
=
sublora_dict
[
lora_id
]
subloras
=
sublora_dict
[
lora_id
]
for
i
,
sublora
in
enumerate
(
subloras
):
for
i
,
sublora
in
enumerate
(
subloras
):
result
[:,
sublora
.
lora_b
.
shape
[
1
]
*
i
:
sublora
.
lora_b
.
shape
[
1
]
*
result
[:,
sublora
.
lora_b
.
shape
[
0
]
*
i
:
sublora
.
lora_b
.
shape
[
0
]
*
(
i
+
1
)]
+=
(
input_
@
sublora
.
lora_a
@
sublora
.
lora_b
*
(
i
+
1
)]
+=
(
sublora
.
scaling
)
input_
@
sublora
.
lora_a
.
T
@
sublora
.
lora_b
.
T
*
sublora
.
scaling
)
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
...
...
tests/lora/test_lora_manager.py
View file @
273690a5
...
@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
...
@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
assert
lora
.
lora_b
is
not
None
assert
lora
.
lora_b
is
not
None
assert
lora
.
lora_a
.
device
==
torch
.
device
(
device
)
assert
lora
.
lora_a
.
device
==
torch
.
device
(
device
)
assert
lora
.
lora_b
.
device
==
torch
.
device
(
device
)
assert
lora
.
lora_b
.
device
==
torch
.
device
(
device
)
assert
(
lora
.
lora_a
.
shape
[
1
]
==
lora
.
lora_b
.
shape
[
0
]
assert
(
lora
.
lora_a
.
shape
[
0
]
==
lora
.
lora_b
.
shape
[
1
]
),
f
"
{
lora
.
lora_a
.
shape
=
}
,
{
lora
.
lora_b
.
shape
=
}
"
),
f
"
{
lora
.
lora_a
.
shape
=
}
,
{
lora
.
lora_b
.
shape
=
}
"
assert
lora
.
lora_a
.
shape
[
1
]
==
8
assert
lora
.
lora_a
.
shape
[
0
]
==
8
embeddings_module
=
next
(
embeddings_module
=
next
(
(
k
for
k
in
EMBEDDING_MODULES
if
k
in
module_name
),
None
)
(
k
for
k
in
EMBEDDING_MODULES
if
k
in
module_name
),
None
)
if
embeddings_module
:
if
embeddings_module
:
...
@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
...
@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
name
,
name
,
8
,
8
,
16
,
16
,
torch
.
rand
([
w
.
shape
[
1
]
,
8
],
device
=
device
),
torch
.
rand
([
8
,
w
.
shape
[
1
]],
device
=
device
),
torch
.
rand
([
8
,
w
.
shape
[
0
]],
device
=
device
),
torch
.
rand
([
w
.
shape
[
0
]
,
8
],
device
=
device
),
)
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
...
@@ -109,8 +109,8 @@ def create_packed_lora(
...
@@ -109,8 +109,8 @@ def create_packed_lora(
replaced_module_name
,
replaced_module_name
,
8
,
8
,
16
,
16
,
torch
.
rand
([
w
.
shape
[
1
]
,
8
],
device
=
device
),
torch
.
rand
([
8
,
w
.
shape
[
1
]],
device
=
device
),
torch
.
rand
([
8
,
w
.
shape
[
0
]
//
len
(
replaced_module_names
)],
torch
.
rand
([
w
.
shape
[
0
]
//
len
(
replaced_module_names
)
,
8
],
device
=
device
),
device
=
device
),
)
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
...
...
tests/lora/utils.py
View file @
273690a5
...
@@ -36,10 +36,10 @@ class DummyLoRAManager:
...
@@ -36,10 +36,10 @@ class DummyLoRAManager:
module_name
,
module_name
,
rank
=
rank
,
rank
=
rank
,
lora_alpha
=
1
,
lora_alpha
=
1
,
lora_a
=
torch
.
rand
([
weight
.
shape
[
1
]
,
rank
],
lora_a
=
torch
.
rand
([
rank
,
weight
.
shape
[
1
]],
dtype
=
weight
.
dtype
,
dtype
=
weight
.
dtype
,
device
=
self
.
_device
),
device
=
self
.
_device
),
lora_b
=
torch
.
rand
([
rank
,
weight
.
shape
[
0
]],
lora_b
=
torch
.
rand
([
weight
.
shape
[
0
]
,
rank
],
dtype
=
weight
.
dtype
,
dtype
=
weight
.
dtype
,
device
=
self
.
_device
),
device
=
self
.
_device
),
)
)
...
@@ -67,8 +67,8 @@ class DummyLoRAManager:
...
@@ -67,8 +67,8 @@ class DummyLoRAManager:
module_name
,
module_name
,
rank
=
rank
,
rank
=
rank
,
lora_alpha
=
1
,
lora_alpha
=
1
,
lora_a
=
torch
.
rand
([
input_dim
,
rank
],
device
=
"cuda"
),
lora_a
=
torch
.
rand
([
rank
,
input_dim
],
device
=
"cuda"
),
lora_b
=
torch
.
rand
([
rank
,
out
put_dim
],
device
=
"cuda"
),
lora_b
=
torch
.
rand
([
output_dim
,
in
put_dim
],
device
=
"cuda"
),
embeddings_tensor
=
embeddings_tensor
,
embeddings_tensor
=
embeddings_tensor
,
)
)
self
.
set_module_lora
(
module_name
,
lora
)
self
.
set_module_lora
(
module_name
,
lora
)
...
...
vllm/lora/layers/base_linear.py
View file @
273690a5
...
@@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
...
@@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_bias
=
self
.
slice_bias
(
lora_bias
)
lora_bias
=
self
.
slice_bias
(
lora_bias
)
self
.
lora_a_stacked
[
0
][
index
,
self
.
lora_a_stacked
[
0
][
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
0
,
:
lora_a
.
shape
[
0
],
:
lora_a
.
shape
[
1
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
lora_a
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
0
][
index
,
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
0
,
:
lora_b
.
shape
[
0
],
:
lora_b
.
shape
[
1
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
lora_b
,
non_blocking
=
True
)
if
lora_bias
is
not
None
:
if
lora_bias
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
=
cast
(
tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
)
assert
len
(
self
.
lora_bias_stacked
)
assert
len
(
self
.
lora_bias_stacked
)
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
.
shape
[
0
]].
copy_
(
self
.
lora_bias_stacked
[
0
][
index
,
0
,
:
lora_bias
.
shape
[
0
]].
copy_
(
lora_bias
.
T
,
non_blocking
=
True
)
lora_bias
,
non_blocking
=
True
)
def
apply
(
self
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
vllm/lora/layers/column_parallel_linear.py
View file @
273690a5
...
@@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
if
self
.
is_merged_col_linear
:
if
self
.
is_merged_col_linear
:
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
output_size
//
2
shard_size
=
self
.
output_size
//
2
offset
=
lora_b
.
shape
[
-
1
]
//
2
offset
=
lora_b
.
shape
[
0
]
//
2
left_weight
=
lora_b
[
:,
tp_rank
*
shard_size
:(
tp_rank
+
1
)
*
left_weight
=
lora_b
[
tp_rank
*
shard_size
:(
tp_rank
+
1
)
*
shard_size
]
shard_size
,
:
]
right_weight
=
lora_b
[
:,
offset
+
tp_rank
*
shard_size
:
offset
+
right_weight
=
lora_b
[
offset
+
tp_rank
*
shard_size
:
offset
+
(
tp_rank
+
1
)
*
shard_size
]
(
tp_rank
+
1
)
*
shard_size
,
:
]
lora_b
=
torch
.
cat
([
left_weight
,
right_weight
],
dim
=
1
)
lora_b
=
torch
.
cat
([
left_weight
,
right_weight
],
dim
=
0
)
# Applicable to cases where the base_layer is
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
# ColumnParallelLinear.
else
:
else
:
...
@@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size
=
self
.
output_size
shard_size
=
self
.
output_size
start_idx
=
tensor_model_parallel_rank
*
shard_size
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[
:,
start_idx
:
end_idx
]
lora_b
=
lora_b
[
start_idx
:
end_idx
,
:
]
return
lora_b
return
lora_b
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for
i
,
(
shard_id
,
shard_size
)
in
enumerate
(
for
i
,
(
shard_id
,
shard_size
)
in
enumerate
(
zip
(
self
.
output_ids
,
self
.
output_slices
)):
zip
(
self
.
output_ids
,
self
.
output_slices
)):
if
(
lora_b_i
:
=
lora_b
[
i
])
is
not
None
:
if
(
lora_b_i
:
=
lora_b
[
i
])
is
not
None
:
sliced_lora_b
[
i
]
=
lora_b_i
[:,
sliced_lora_b
[
i
]
=
lora_b_i
[
shard_size
*
shard_id
:
shard_size
*
shard_size
*
shard_id
:
shard_size
*
(
shard_id
+
1
),
:]
(
shard_id
+
1
)]
return
sliced_lora_b
return
sliced_lora_b
def
slice_bias
(
def
slice_bias
(
...
@@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for
i
in
range
(
self
.
n_slices
):
for
i
in
range
(
self
.
n_slices
):
if
(
lora_a_i
:
=
lora_a
[
i
])
is
not
None
:
if
(
lora_a_i
:
=
lora_a
[
i
])
is
not
None
:
self
.
lora_a_stacked
[
i
][
self
.
lora_a_stacked
[
i
][
index
,
0
,
:
lora_a_i
.
shape
[
1
],
:
lora_a_i
.
shape
[
0
]].
copy_
(
index
,
0
,
:
lora_a_i
.
shape
[
0
],
:
lora_a_i
.
shape
[
1
]].
copy_
(
lora_a_i
.
T
,
non_blocking
=
True
)
lora_a_i
,
non_blocking
=
True
)
if
(
lora_b_i
:
=
lora_b
[
i
])
is
not
None
:
if
(
lora_b_i
:
=
lora_b
[
i
])
is
not
None
:
self
.
lora_b_stacked
[
i
][
self
.
lora_b_stacked
[
i
][
index
,
0
,
:
lora_b_i
.
shape
[
1
],
:
lora_b_i
.
shape
[
0
]].
copy_
(
index
,
0
,
:
lora_b_i
.
shape
[
0
],
:
lora_b_i
.
shape
[
1
]].
copy_
(
lora_b_i
.
T
,
non_blocking
=
True
)
lora_b_i
,
non_blocking
=
True
)
if
lora_bias
is
not
None
:
if
lora_bias
is
not
None
:
self
.
lora_bias_stacked
=
cast
(
tuple
[
torch
.
Tensor
,
...],
self
.
lora_bias_stacked
=
cast
(
tuple
[
torch
.
Tensor
,
...],
...
@@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if
(
lora_bias_i
:
=
lora_bias
[
i
])
is
not
None
:
if
(
lora_bias_i
:
=
lora_bias
[
i
])
is
not
None
:
self
.
lora_bias_stacked
[
i
][
index
,
self
.
lora_bias_stacked
[
i
][
index
,
0
,
:
lora_bias_i
.
shape
[
0
]].
copy_
(
0
,
:
lora_bias_i
.
shape
[
0
]].
copy_
(
lora_bias_i
.
T
,
lora_bias_i
,
non_blocking
=
True
)
non_blocking
=
True
)
@
classmethod
@
classmethod
...
@@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
q_shard_id
=
tp_rank
self
.
q_shard_id
=
tp_rank
self
.
kv_shard_id
=
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
self
.
kv_shard_id
=
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
lora_b_q
=
lora_b
[
:,
self
.
q_proj_shard_size
*
lora_b_q
=
lora_b
[
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
)]
(
self
.
q_shard_id
+
1
)
,
:
]
k_offset
=
self
.
q_proj_total_size
k_offset
=
self
.
q_proj_total_size
lora_b_k
=
lora_b
[
:,
k_offset
+
lora_b_k
=
lora_b
[
k_offset
+
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
k_offset
+
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
k_offset
+
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)
,
:
]
v_offset
=
k_offset
+
self
.
kv_proj_total_size
v_offset
=
k_offset
+
self
.
kv_proj_total_size
lora_b_v
=
lora_b
[
:,
v_offset
+
lora_b_v
=
lora_b
[
v_offset
+
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
v_offset
+
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
v_offset
+
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)
,
:
]
lora_b
=
torch
.
cat
([
lora_b_q
,
lora_b_k
,
lora_b_v
],
dim
=
1
)
lora_b
=
torch
.
cat
([
lora_b_q
,
lora_b_k
,
lora_b_v
],
dim
=
0
)
return
lora_b
return
lora_b
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
...
@@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
lora_a
=
lora_a
[
:,
start_idx
:
start_idx
+
shard_size
]
lora_a
=
lora_a
[
start_idx
:
start_idx
+
shard_size
,
:
]
return
lora_a
return
lora_a
def
apply
(
self
,
def
apply
(
self
,
...
@@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
...
@@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
output_shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
output_shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
output_start_idx
=
self
.
tp_rank
*
output_shard_size
output_start_idx
=
self
.
tp_rank
*
output_shard_size
lora_a
=
[
lora_a
=
[
lora_a
[
0
][
:,
output_start_idx
:
output_start_idx
+
lora_a
[
0
][
output_start_idx
:
output_start_idx
+
output_shard_size
]
if
lora_a
[
0
]
is
not
None
else
None
,
output_shard_size
,
:
]
if
lora_a
[
0
]
is
not
None
else
None
,
lora_a
[
1
][
:,
output_start_idx
:
output_start_idx
+
lora_a
[
1
][
output_start_idx
:
output_start_idx
+
output_shard_size
]
if
lora_a
[
1
]
is
not
None
else
None
,
output_shard_size
,
:
]
if
lora_a
[
1
]
is
not
None
else
None
,
]
]
return
lora_a
return
lora_a
...
@@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
...
@@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
shard_size
=
self
.
lora_a_stacked
[
0
].
shape
[
2
]
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
lora_a
=
lora_a
[
:,
start_idx
:
start_idx
+
shard_size
]
lora_a
=
lora_a
[
start_idx
:
start_idx
+
shard_size
,
:
]
return
lora_a
return
lora_a
def
apply
(
self
,
def
apply
(
self
,
...
@@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
...
@@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
shard_size
=
[
self
.
lora_a_stacked
[
i
].
shape
[
2
]
for
i
in
range
(
3
)]
shard_size
=
[
self
.
lora_a_stacked
[
i
].
shape
[
2
]
for
i
in
range
(
3
)]
start_idx
=
[
self
.
tp_rank
*
shard_size
[
i
]
for
i
in
range
(
3
)]
start_idx
=
[
self
.
tp_rank
*
shard_size
[
i
]
for
i
in
range
(
3
)]
lora_a
=
[
lora_a
=
[
lora_a
[
0
][
:,
start_idx
[
0
]:
start_idx
[
0
]
+
lora_a
[
0
][
start_idx
[
0
]:
start_idx
[
0
]
+
shard_size
[
0
]]
if
lora_a
[
0
]
is
not
None
else
None
,
shard_size
[
0
]
,
:
]
if
lora_a
[
0
]
is
not
None
else
None
,
lora_a
[
1
][
:,
start_idx
[
1
]:
start_idx
[
1
]
+
lora_a
[
1
][
start_idx
[
1
]:
start_idx
[
1
]
+
shard_size
[
1
]]
if
lora_a
[
1
]
is
not
None
else
None
,
shard_size
[
1
]
,
:
]
if
lora_a
[
1
]
is
not
None
else
None
,
lora_a
[
2
][
:,
start_idx
[
2
]:
start_idx
[
2
]
+
lora_a
[
2
][
start_idx
[
2
]:
start_idx
[
2
]
+
shard_size
[
2
]]
if
lora_a
[
2
]
is
not
None
else
None
,
shard_size
[
2
]
,
:
]
if
lora_a
[
2
]
is
not
None
else
None
,
]
]
return
lora_a
return
lora_a
...
...
vllm/lora/layers/logits_processor.py
View file @
273690a5
...
@@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
):
):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
self
.
lora_a_stacked
[
index
,
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
0
,
:
lora_a
.
shape
[
0
],
:
lora_a
.
shape
[
1
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
lora_a
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
0
,
:
lora_b
.
shape
[
0
],
:
lora_b
.
shape
[
1
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
lora_b
,
non_blocking
=
True
)
if
embeddings_tensor
is
not
None
:
if
embeddings_tensor
is
not
None
:
self
.
embeddings_tensors
[
self
.
embeddings_tensors
[
index
,
index
,
...
...
vllm/lora/layers/row_parallel_linear.py
View file @
273690a5
...
@@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
...
@@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size
=
self
.
input_size
shard_size
=
self
.
input_size
start_idx
=
self
.
tp_rank
*
shard_size
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:
]
lora_a
=
lora_a
[
:,
start_idx
:
end_idx
]
return
lora_a
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
...
@@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
shard_size
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
shard_size
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
start_idx
=
self
.
tp_rank
*
shard_size
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[
:,
start_idx
:
end_idx
]
lora_b
=
lora_b
[
start_idx
:
end_idx
,:
]
return
lora_b
return
lora_b
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
slice_bias
(
self
,
bias
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/lora/layers/vocal_parallel_embedding.py
View file @
273690a5
...
@@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
self
.
lora_a_stacked
[
index
,
:
lora_a
.
shape
[
0
],
:
lora_a
.
shape
[
1
]].
copy_
(
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
lora_a
,
non_blocking
=
True
)
# so we need transpose here
self
.
lora_a_stacked
[
index
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
index
,
self
.
lora_b_stacked
[
index
,
0
,
:
lora_b
.
shape
[
1
],
:
lora_b
.
shape
[
0
]].
copy_
(
0
,
:
lora_b
.
shape
[
0
],
:
lora_b
.
shape
[
1
]].
copy_
(
lora_b
.
T
,
non_blocking
=
True
)
lora_b
,
non_blocking
=
True
)
if
embeddings_tensor
is
not
None
:
if
embeddings_tensor
is
not
None
:
self
.
embeddings_tensors
[
self
.
embeddings_tensors
[
index
,
index
,
...
...
vllm/lora/lora_weights.py
View file @
273690a5
...
@@ -86,11 +86,11 @@ class LoRALayerWeights:
...
@@ -86,11 +86,11 @@ class LoRALayerWeights:
embeddings_tensor_dim
:
Optional
[
int
]
=
None
,
embeddings_tensor_dim
:
Optional
[
int
]
=
None
,
bias_enabled
:
Optional
[
bool
]
=
False
)
->
"LoRALayerWeights"
:
bias_enabled
:
Optional
[
bool
]
=
False
)
->
"LoRALayerWeights"
:
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
lora_a
=
torch
.
zeros
([
input_dim
,
rank
],
lora_a
=
torch
.
zeros
([
rank
,
input_dim
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
lora_b
=
torch
.
zeros
([
rank
,
output_dim
],
lora_b
=
torch
.
zeros
([
output_dim
,
rank
],
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
...
...
vllm/lora/models.py
View file @
273690a5
...
@@ -152,30 +152,29 @@ class LoRAModel:
...
@@ -152,30 +152,29 @@ class LoRAModel:
module_name
,
peft_helper
,
lora_embeddings_tensor
)
module_name
,
peft_helper
,
lora_embeddings_tensor
)
if
is_bias
:
if
is_bias
:
loras
[
module_name
].
bias
=
tensor
.
to
(
device
=
device
,
loras
[
module_name
].
bias
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
)
dtype
=
dtype
).
t
()
bias
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
)
bias
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
).
t
()
if
pin_memory
:
if
pin_memory
:
bias
=
bias
.
pin_memory
()
bias
=
bias
.
pin_memory
()
loras
[
module_name
].
bias
=
bias
loras
[
module_name
].
bias
=
bias
elif
is_lora_a
:
elif
is_lora_a
:
loras
[
module_name
].
lora_a
=
tensor
.
to
(
device
=
device
,
loras
[
module_name
].
lora_a
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
)
.
t
()
dtype
=
dtype
)
if
pin_memory
:
if
pin_memory
:
loras
[
module_name
].
lora_a
=
loras
[
loras
[
module_name
].
lora_a
=
loras
[
module_name
].
lora_a
.
pin_memory
()
module_name
].
lora_a
.
pin_memory
()
else
:
else
:
loras
[
module_name
].
lora_b
=
tensor
.
to
(
device
=
device
,
loras
[
module_name
].
lora_b
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
)
.
t
()
dtype
=
dtype
)
assert
embedding_padding_modules
is
not
None
assert
embedding_padding_modules
is
not
None
if
any
(
name
in
module_name
if
any
(
name
in
module_name
for
name
in
embedding_padding_modules
for
name
in
embedding_padding_modules
)
and
target_embedding_padding
is
not
None
:
)
and
target_embedding_padding
is
not
None
:
lora_b
=
loras
[
module_name
].
lora_b
lora_b
=
loras
[
module_name
].
lora_b
assert
target_embedding_padding
>=
lora_b
.
shape
[
1
]
assert
target_embedding_padding
>=
lora_b
.
shape
[
0
]
addition
=
target_embedding_padding
-
lora_b
.
shape
[
1
]
addition
=
target_embedding_padding
-
lora_b
.
shape
[
0
]
loras
[
module_name
].
lora_b
=
torch
.
nn
.
functional
.
pad
(
loras
[
module_name
].
lora_b
=
torch
.
nn
.
functional
.
pad
(
lora_b
,
(
0
,
addition
))
lora_b
,
(
0
,
0
,
0
,
addition
))
if
pin_memory
:
if
pin_memory
:
loras
[
module_name
].
lora_b
=
loras
[
loras
[
module_name
].
lora_b
=
loras
[
module_name
].
lora_b
.
pin_memory
()
module_name
].
lora_b
.
pin_memory
()
...
@@ -585,7 +584,6 @@ class LoRAModelManager:
...
@@ -585,7 +584,6 @@ class LoRAModelManager:
"cpu"
,
"cpu"
,
bias_enabled
=
bias_enabled
,
bias_enabled
=
bias_enabled
,
)
)
lora
.
optimize
()
else
:
else
:
parts
=
module_name
.
split
(
"."
)
parts
=
module_name
.
split
(
"."
)
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
...
@@ -600,7 +598,6 @@ class LoRAModelManager:
...
@@ -600,7 +598,6 @@ class LoRAModelManager:
"cpu"
,
"cpu"
,
bias_enabled
=
bias_enabled
,
bias_enabled
=
bias_enabled
,
)
)
lora
.
optimize
()
subloras
.
append
(
lora
)
subloras
.
append
(
lora
)
lora
=
PackedLoRALayerWeights
.
pack
(
subloras
)
lora
=
PackedLoRALayerWeights
.
pack
(
subloras
)
model
.
loras
[
module_name
]
=
lora
model
.
loras
[
module_name
]
=
lora
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment