Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99d7cabd
Unverified
Commit
99d7cabd
authored
Aug 03, 2024
by
Jee Jee Li
Committed by
GitHub
Aug 02, 2024
Browse files
[LoRA] ReplicatedLinear support LoRA (#7081)
parent
fb2c1c86
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
199 additions
and
0 deletions
+199
-0
tests/lora/test_layers.py
tests/lora/test_layers.py
+103
-0
vllm/lora/layers.py
vllm/lora/layers.py
+94
-0
vllm/lora/utils.py
vllm/lora/utils.py
+2
-0
No files found.
tests/lora/test_layers.py
View file @
99d7cabd
...
...
@@ -22,6 +22,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
QKVParallelLinearWithLora
,
ReplicatedLinearWithLoRA
,
RowParallelLinearWithLoRA
,
VocabParallelEmbeddingWithLoRA
)
# yapf: enable
...
...
@@ -31,6 +32,7 @@ from vllm.lora.punica import PunicaWrapper
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
...
...
@@ -545,6 +547,107 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
atol
=
atol
)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
    """Compare ReplicatedLinearWithLoRA against a hand-computed reference.

    For ten seeds, a fresh ``ReplicatedLinear`` / ``ReplicatedLinearWithLoRA``
    pair is built and populated with random LoRAs. The wrapped layer's output
    must match the base layer's output plus
    ``input @ lora_a @ lora_b * scaling`` for each input. Afterwards every
    LoRA slot is reset and the wrapped layer must match the base layer alone.
    """
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )

    def _build_layer_pair():
        # Base layer with random weights and its LoRA-enabled wrapper.
        base = ReplicatedLinear(
            4096,
            4096,
            bias=False,
            params_dtype=torch.float16,
        )
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = ReplicatedLinearWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def _sample_inputs(active_ids):
        # Same input recipe for both phases; only the active ids differ.
        return create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )

    def _activate_mapping(index_mapping, prompt_mapping, id_to_index):
        # Push the current token/prompt -> LoRA mapping into the wrapper.
        mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

    def _assert_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        assert torch.allclose(actual, expected, rtol=rtol, atol=atol)

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = _build_layer_pair()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = _sample_inputs(
            list(lora_dict.keys()))
        _activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus the LoRA delta, one input at a time.
        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = linear(input_)[0]
            reference += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(reference)
        expected_result = torch.cat(expected_results)

        _assert_close(lora_result, expected_result)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = _sample_inputs([0])
        _activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        _assert_close(lora_result, expected_result)
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"orientation"
,
[
"row"
,
"column"
])
...
...
vllm/lora/layers.py
View file @
99d7cabd
...
...
@@ -21,6 +21,7 @@ from vllm.lora.punica import PunicaWrapper
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
(
...
...
@@ -262,6 +263,99 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
return
type
(
source_layer
)
is
VocabParallelEmbedding
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``ReplicatedLinear`` base layer.

    Holds stacked LoRA A/B buffers (one slot per LoRA) and adds the LoRA
    contribution to the base layer's output through ``self.punica_wrapper``.
    ``punica_wrapper`` is not assigned in this class and must be set before
    ``apply``/``forward`` is called.
    """

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        """Wrap ``base_layer`` and cache its sizes and LoRA device."""
        super().__init__()
        self.base_layer = base_layer
        self.input_size = base_layer.input_size
        self.output_size = base_layer.output_size
        self.device = _get_lora_device(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zero-filled stacked LoRA buffers.

        Args:
            max_loras: Number of LoRA slots (first dimension of both buffers).
            lora_config: Supplies ``max_lora_rank`` and ``lora_dtype``.
            model_config: Accepted but not used by this method.
        """
        self.lora_config = lora_config
        rank = lora_config.max_lora_rank
        dtype = lora_config.lora_dtype

        # Shape: (max_loras, 1, max_lora_rank, input_size)
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            rank,
            self.input_size,
            dtype=dtype,
            device=self.device,
        )
        # Shape: (max_loras, 1, output_size, max_lora_rank)
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            rank,
            dtype=dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        """Zero both stacked buffers for slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load one LoRA's weights into slot ``index``.

        The slot is cleared first, then ``lora_a.T`` and ``lora_b.T`` are
        copied into the leading region of the matching buffer.
        ``embeddings_tensor`` is accepted but not used by this method.
        """
        self.reset_lora(index)

        a_rows, a_cols = lora_a.shape[1], lora_a.shape[0]
        a_slot = self.lora_a_stacked[index, 0, :a_rows, :a_cols]
        a_slot.copy_(lora_a.T, non_blocking=True)

        b_rows, b_cols = lora_b.shape[1], lora_b.shape[0]
        b_slot = self.lora_b_stacked[index, 0, :b_rows, :b_cols]
        b_slot.copy_(lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer's quant method, then add the LoRA term.

        ``punica_wrapper.add_lora`` receives the output tensor, the input,
        both stacked buffers and ``1.0`` (presumably the LoRA scale; confirm
        against ``PunicaWrapper``). The same output tensor is returned.
        """
        base = self.base_layer
        output = base.quant_method.apply(base, x, bias)
        self.punica_wrapper.add_lora(
            output,
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            1.0,
        )
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias: the base layer's bias when ``skip_bias_add`` is set
              (in which case it is not passed to ``apply``), else ``None``.
        """
        base = self.base_layer
        if base.skip_bias_add:
            # Bias is handed back to the caller instead of being applied.
            applied_bias, output_bias = None, base.bias
        else:
            applied_bias, output_bias = base.bias, None

        # Matrix multiply.
        output = self.apply(input_, applied_bias)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only when ``source_layer`` is exactly ReplicatedLinear.

        The check is on the exact type, so subclasses do not match.
        """
        return type(source_layer) is ReplicatedLinear
class
ColumnParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
"""
LoRA on top of ColumnParallelLinear layer.
...
...
vllm/lora/utils.py
View file @
99d7cabd
...
...
@@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
QKVParallelLinearWithLora
,
ReplicatedLinearWithLoRA
,
RowParallelLinearWithLoRA
,
VocabParallelEmbeddingWithLoRA
)
# yapf: enable
...
...
@@ -38,6 +39,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
QKVParallelLinearWithLora
,
MergedQKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
,
ReplicatedLinearWithLoRA
,
LogitsProcessorWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
QKVParallelLinearWithShardedLora
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment