Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d17c8477
Unverified
Commit
d17c8477
authored
Apr 19, 2024
by
Jee Li
Committed by
GitHub
Apr 19, 2024
Browse files
[Bugfix] Fix LoRA loading check (#4138)
Co-authored-by:
simon-mo
<
simon.mo@hey.com
>
parent
a134ef6f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
3 deletions
+29
-3
tests/lora/conftest.py
tests/lora/conftest.py
+6
-0
tests/lora/test_lora_checkpoints.py
tests/lora/test_lora_checkpoints.py
+20
-2
vllm/lora/models.py
vllm/lora/models.py
+3
-1
No files found.
tests/lora/conftest.py
View file @
d17c8477
...
...
@@ -143,6 +143,12 @@ def baichuan_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan7b-text2sql-spider"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
baichuan_zero_lora_files
():
# all the lora_B weights are initialized to zero.
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan7b-zero-init"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
tinyllama_lora_files
():
return
snapshot_download
(
repo_id
=
"jashing/tinyllama-colorist-lora"
)
...
...
tests/lora/test_lora_checkpoints.py
View file @
d17c8477
...
...
@@ -3,9 +3,16 @@ import pytest
from
vllm.lora.models
import
LoRAModel
from
vllm.model_executor.models.baichuan
import
BaiChuanBaseForCausalLM
lora_lst
=
[
"baichuan7B"
,
"baichuan7B-zero"
,
"chatglm3-6b"
]
@
pytest
.
mark
.
parametrize
(
"lora_name"
,
[
"baichuan7B"
,
"chatglm3-6b"
])
def
test_load_checkpoints
(
lora_name
,
chatglm3_lora_files
,
baichuan_lora_files
):
@
pytest
.
mark
.
parametrize
(
"lora_name"
,
lora_lst
)
def
test_load_checkpoints
(
lora_name
,
baichuan_lora_files
,
baichuan_zero_lora_files
,
chatglm3_lora_files
,
):
supported_lora_modules
=
BaiChuanBaseForCausalLM
.
supported_lora_modules
packed_modules_mapping
=
BaiChuanBaseForCausalLM
.
packed_modules_mapping
embedding_modules
=
BaiChuanBaseForCausalLM
.
embedding_modules
...
...
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
device
=
"cpu"
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embed_padding_modules
)
elif
lora_name
==
"baichuan7B-zero"
:
# Test that target_modules entries carrying a prefix,
# such as "model.layers.0.self_attn.W_pack", are accepted;
# loading should succeed.
LoRAModel
.
from_local_checkpoint
(
baichuan_zero_lora_files
,
expected_lora_modules
,
lora_model_id
=
1
,
device
=
"cpu"
,
embedding_modules
=
embedding_modules
,
embedding_padding_modules
=
embed_padding_modules
)
else
:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
...
...
vllm/lora/models.py
View file @
d17c8477
...
...
@@ -212,7 +212,9 @@ class LoRAModel:
target_modules
=
config
[
"target_modules"
]
unexpected_modules
=
[]
for
module
in
target_modules
:
if
module
not
in
expected_lora_modules
:
# Also accept prefixed module names, such as: layers.11.self_attn.k_proj (only the last dotted component is checked)
part_name
=
module
.
split
(
"."
)[
-
1
]
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module
)
# loaded lora's target modules must be a subset of expected_lora_modules
if
unexpected_modules
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment