Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11dd6ebb
Unverified
Commit
11dd6ebb
authored
Apr 10, 2024
by
Jee Li
Committed by
GitHub
Apr 09, 2024
Browse files
[Misc] Avoid loading incorrect LoRA config (#3777)
parent
6c0b0451
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
66 additions
and
2 deletions
+66
-2
tests/lora/test_lora_checkpoints.py
tests/lora/test_lora_checkpoints.py
+40
-0
vllm/lora/models.py
vllm/lora/models.py
+15
-2
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+11
-0
No files found.
tests/lora/test_lora_checkpoints.py
0 → 100644
View file @
11dd6ebb
import
pytest
from
vllm.lora.models
import
LoRAModel
from
vllm.model_executor.models.baichuan
import
BaiChuanBaseForCausalLM
@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
    """Check that only a LoRA checkpoint matching the base model loads.

    The expected LoRA target modules are derived from
    ``BaiChuanBaseForCausalLM``. Loading the Baichuan adapter must succeed,
    while loading the ChatGLM3 adapter against the same expectations must
    raise a ``ValueError`` asking the user to verify the LoRA module.

    ``chatglm3_lora_files`` and ``baichuan_lora_files`` are fixtures passed
    as the checkpoint location to ``LoRAModel.from_local_checkpoint``.
    """
    base_model = BaiChuanBaseForCausalLM
    supported = base_model.supported_lora_modules
    packed = base_model.packed_modules_mapping

    # Keyword arguments shared by both checkpoint loads below.
    load_kwargs = dict(
        lora_model_id=1,
        device="cpu",
        embedding_modules=base_model.embedding_modules,
        embedding_padding_modules=base_model.embedding_padding_modules,
    )

    # A packed module is replaced by its constituent sub-modules; every
    # other supported module is expected under its own name.
    expected_lora_modules = [
        target
        for name in supported
        for target in (packed[name] if name in packed else [name])
    ]

    if lora_name == "baichuan7B":
        # The adapter belongs to the base model, so loading must succeed.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
            **load_kwargs,
        )
        return

    # The ChatGLM3 adapter targets modules the Baichuan model does not
    # expect, so loading it must be rejected with the error below.
    expected_error = "Please verify that the loaded LoRA module is correct"
    with pytest.raises(ValueError, match=expected_error):
        LoRAModel.from_local_checkpoint(
            chatglm3_lora_files,
            expected_lora_modules,
            **load_kwargs,
        )
vllm/lora/models.py
View file @
11dd6ebb
...
@@ -191,6 +191,7 @@ class LoRAModel:
...
@@ -191,6 +191,7 @@ class LoRAModel:
def
from_local_checkpoint
(
def
from_local_checkpoint
(
cls
,
cls
,
lora_dir
:
str
,
lora_dir
:
str
,
expected_lora_modules
:
List
[
str
],
lora_model_id
:
Optional
[
int
]
=
None
,
lora_model_id
:
Optional
[
int
]
=
None
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
@@ -206,6 +207,20 @@ class LoRAModel:
...
@@ -206,6 +207,20 @@ class LoRAModel:
lora_dir
,
"new_embeddings.safetensors"
)
lora_dir
,
"new_embeddings.safetensors"
)
new_embeddings_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
new_embeddings_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"new_embeddings.bin"
)
"new_embeddings.bin"
)
with
open
(
lora_config_path
)
as
f
:
config
=
json
.
load
(
f
)
target_modules
=
config
[
"target_modules"
]
unexpected_modules
=
[]
for
module
in
target_modules
:
if
module
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module
)
# loaded lora's target modules must be a subset of expected_lora_modules
if
unexpected_modules
:
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
f
" but received
{
unexpected_modules
}
."
f
" Please verify that the loaded LoRA module is correct"
)
if
os
.
path
.
isfile
(
lora_tensor_path
):
if
os
.
path
.
isfile
(
lora_tensor_path
):
tensors
=
safetensors
.
torch
.
load_file
(
lora_tensor_path
)
tensors
=
safetensors
.
torch
.
load_file
(
lora_tensor_path
)
elif
os
.
path
.
isfile
(
lora_bin_file_path
):
elif
os
.
path
.
isfile
(
lora_bin_file_path
):
...
@@ -220,8 +235,6 @@ class LoRAModel:
...
@@ -220,8 +235,6 @@ class LoRAModel:
elif
os
.
path
.
isfile
(
new_embeddings_bin_file_path
):
elif
os
.
path
.
isfile
(
new_embeddings_bin_file_path
):
embeddings
=
torch
.
load
(
new_embeddings_bin_file_path
)
embeddings
=
torch
.
load
(
new_embeddings_bin_file_path
)
with
open
(
lora_config_path
)
as
f
:
config
=
json
.
load
(
f
)
rank
=
config
[
"r"
]
rank
=
config
[
"r"
]
lora_alpha
=
config
[
"lora_alpha"
]
lora_alpha
=
config
[
"lora_alpha"
]
return
cls
.
from_lora_tensors
(
return
cls
.
from_lora_tensors
(
...
...
vllm/lora/worker_manager.py
View file @
11dd6ebb
...
@@ -136,8 +136,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -136,8 +136,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def
_load_lora
(
self
,
lora_request
:
LoRARequest
)
->
LoRAModel
:
def
_load_lora
(
self
,
lora_request
:
LoRARequest
)
->
LoRAModel
:
try
:
try
:
model
=
self
.
_lora_manager
.
model
supported_lora_modules
=
model
.
supported_lora_modules
packed_modules_mapping
=
model
.
packed_modules_mapping
expected_lora_modules
=
[]
for
module
in
supported_lora_modules
:
if
module
in
packed_modules_mapping
:
expected_lora_modules
.
extend
(
packed_modules_mapping
[
module
])
else
:
expected_lora_modules
.
append
(
module
)
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_request
.
lora_local_path
,
lora_request
.
lora_local_path
,
expected_lora_modules
,
lora_model_id
=
lora_request
.
lora_int_id
,
lora_model_id
=
lora_request
.
lora_int_id
,
device
=
"cpu"
,
device
=
"cpu"
,
dtype
=
self
.
lora_config
.
lora_dtype
,
dtype
=
self
.
lora_config
.
lora_dtype
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment