Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
36ea7907
Unverified
Commit
36ea7907
authored
Oct 11, 2024
by
Jee Jee Li
Committed by
GitHub
Oct 11, 2024
Browse files
[Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275)
parent
e808156f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
5 deletions
+59
-5
tests/lora/conftest.py
tests/lora/conftest.py
+5
-0
tests/lora/test_lora_checkpoints.py
tests/lora/test_lora_checkpoints.py
+15
-2
vllm/lora/models.py
vllm/lora/models.py
+5
-2
vllm/lora/utils.py
vllm/lora/utils.py
+34
-1
No files found.
tests/lora/conftest.py
View file @
36ea7907
...
...
@@ -199,6 +199,11 @@ def baichuan_zero_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan7b-zero-init"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
baichuan_regex_lora_files
():
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan-7b-lora-zero-regex"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
minicpmv_lora_files
():
return
snapshot_download
(
repo_id
=
"jeeejeee/minicpmv25-lora-pokemon"
)
...
...
tests/lora/test_lora_checkpoints.py
View file @
36ea7907
...
...
@@ -5,7 +5,9 @@ import pytest
from
vllm.lora.models
import
LoRAModel
from
vllm.model_executor.models.baichuan
import
BaiChuanBaseForCausalLM
lora_lst
=
[
"baichuan7B"
,
"baichuan7B-zero"
,
"chatglm3-6b"
]
lora_lst
=
[
"baichuan7B"
,
"baichuan7B-zero"
,
"baichuan7B-zero-regex"
,
"chatglm3-6b"
]
@
pytest
.
mark
.
parametrize
(
"lora_name"
,
lora_lst
)
...
...
@@ -13,6 +15,7 @@ def test_load_checkpoints(
lora_name
,
baichuan_lora_files
,
baichuan_zero_lora_files
,
baichuan_regex_lora_files
,
chatglm3_lora_files
,
):
supported_lora_modules
=
BaiChuanBaseForCausalLM
.
supported_lora_modules
...
...
@@ -36,7 +39,7 @@ def test_load_checkpoints(
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embed_padding_modules
)
elif
lora_name
==
"baichuan7B-zero"
:
#Test that the target_modules contain prefix
#
Test that the target_modules contain prefix
# such as "model.layers.0.self_atten.W_pack", and
# the test should pass.
LoRAModel
.
from_local_checkpoint
(
...
...
@@ -46,6 +49,16 @@ def test_load_checkpoints(
device
=
"cpu"
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embed_padding_modules
)
elif
lora_name
==
"baichuan7B-zero-regex"
:
# Test that the `target_modules` in the form of regular expressions,
# such as `model\\..*(W_pack|o_proj)`, and the test should pass.
LoRAModel
.
from_local_checkpoint
(
baichuan_regex_lora_files
,
expected_lora_modules
,
lora_model_id
=
1
,
device
=
"cpu"
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embed_padding_modules
)
else
:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
...
...
vllm/lora/models.py
View file @
36ea7907
...
...
@@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.punica
import
PunicaWrapper
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
...
@@ -233,6 +234,8 @@ class LoRAModel(AdapterModel):
# modules.
unexpected_modules
=
[]
target_modules
=
config
[
"target_modules"
]
if
not
isinstance
(
target_modules
,
list
):
target_modules
=
[
target_modules
]
for
module
in
target_modules
:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
...
...
@@ -243,8 +246,8 @@ class LoRAModel(AdapterModel):
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if
unexpected_modules
:
print
(
unexpected_modules
,
"
modules
"
)
if
unexpected_modules
and
not
is_regex_target_modules
(
config
[
"target_modules"
],
expected_lora_
modules
)
:
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
...
...
vllm/lora/utils.py
View file @
36ea7907
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
import
re
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
huggingface_hub
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
...
...
@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
def
is_regex_target_modules
(
load_modules
:
Union
[
str
,
List
[
str
]],
expected_lora_modules
:
List
[
str
])
->
bool
:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
determine whether the suffix in the regular expression is present in the
`expected_lora_modules`.
"""
def
is_valid_regex
(
pattern
):
try
:
re
.
compile
(
pattern
)
return
True
except
re
.
error
:
return
False
def
is_subset
(
sub_list
,
full_list
):
return
set
(
sub_list
).
issubset
(
set
(
full_list
))
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
if
not
isinstance
(
load_modules
,
str
):
return
False
if
is_valid_regex
(
load_modules
):
match
=
re
.
search
(
r
"\((.*?)\)\$?$"
,
load_modules
)
if
match
:
suffix
=
match
.
group
(
1
).
split
(
"|"
)
return
is_subset
(
suffix
,
expected_lora_modules
)
return
False
def
get_adapter_absolute_path
(
lora_path
:
str
)
->
str
:
"""
Resolves the given lora_path to an absolute local path.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment