Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
285178b3
Unverified
Commit
285178b3
authored
Aug 22, 2025
by
Jee Jee Li
Committed by
GitHub
Aug 22, 2025
Browse files
[V0 Deprecation] Remove V0 LoRA test (#23418)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
88016c37
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
158 additions
and
116 deletions
+158
-116
tests/lora/conftest.py
tests/lora/conftest.py
+4
-27
tests/lora/test_add_lora.py
tests/lora/test_add_lora.py
+4
-7
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+1
-4
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+68
-62
tests/lora/test_mixtral.py
tests/lora/test_mixtral.py
+0
-1
tests/lora/test_worker.py
tests/lora/test_worker.py
+5
-15
tests/lora/utils.py
tests/lora/utils.py
+76
-0
No files found.
tests/lora/conftest.py
View file @
285178b3
...
...
@@ -3,15 +3,13 @@
import
tempfile
from
collections
import
OrderedDict
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
import
pytest
import
torch
import
torch.nn
as
nn
from
huggingface_hub
import
snapshot_download
import
vllm
from
vllm.config
import
LoRAConfig
from
vllm.distributed
import
(
cleanup_dist_env_and_memory
,
init_distributed_environment
,
initialize_model_parallel
)
...
...
@@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
from
vllm.platforms
import
current_platform
...
...
@@ -104,6 +101,7 @@ def dummy_model() -> nn.Module:
]))
model
.
config
=
MagicMock
()
model
.
embedding_modules
=
{
"lm_head"
:
"lm_head"
}
model
.
unpadded_vocab_size
=
32000
return
model
...
...
@@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module:
],
}
model
.
embedding_modules
=
{
"lm_head"
:
"lm_head"
}
model
.
unpadded_vocab_size
=
32000
return
model
...
...
@@ -221,29 +221,6 @@ def phi2_lora_files():
return
snapshot_download
(
repo_id
=
"isotr0py/phi-2-test-sql-lora"
)
@
pytest
.
fixture
def
llama_2_7b_engine_extra_embeddings
():
cleanup_dist_env_and_memory
(
shutdown_ray
=
True
)
get_model_old
=
get_model
def
get_model_patched
(
**
kwargs
):
kwargs
[
"vllm_config"
].
lora_config
=
LoRAConfig
(
max_loras
=
4
,
max_lora_rank
=
8
)
return
get_model_old
(
**
kwargs
)
with
patch
(
"vllm.worker.model_runner.get_model"
,
get_model_patched
):
engine
=
vllm
.
LLM
(
"meta-llama/Llama-2-7b-hf"
,
enable_lora
=
False
)
yield
engine
.
llm_engine
del
engine
cleanup_dist_env_and_memory
(
shutdown_ray
=
True
)
@
pytest
.
fixture
def
llama_2_7b_model_extra_embeddings
(
llama_2_7b_engine_extra_embeddings
):
yield
(
llama_2_7b_engine_extra_embeddings
.
model_executor
.
driver_worker
.
model_runner
.
model
)
@
pytest
.
fixture
def
reset_default_device
():
"""
...
...
tests/lora/test_add_lora.py
View file @
285178b3
...
...
@@ -5,7 +5,6 @@ import time
import
pytest
import
vllm.envs
as
env
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.api_server
import
(
build_async_engine_client_from_engine_args
)
...
...
@@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files):
# Run with warmup
add_lora_tasks
=
[
llm
.
add_lora
(
lr
)
for
lr
in
warmup_run_requests
]
add_lora_results
=
await
asyncio
.
gather
(
*
add_lora_tasks
)
if
env
.
VLLM_USE_V1
:
# Test that all all_lora calls are successful.
assert
all
(
add_lora_results
)
else
:
# No way to check V0 engine results as the calls just return None.
pass
# Test that all all_lora calls are successful.
assert
all
(
add_lora_results
)
time_with_add_lora
=
await
requests_processing_time
(
llm
,
warmup_run_requests
)
...
...
tests/lora/test_llama_tp.py
View file @
285178b3
...
...
@@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files):
enable_lora
=
True
,
# also test odd max_num_seqs
max_num_seqs
=
13
,
max_loras
=
4
,
enable_chunked_prefill
=
True
)
max_loras
=
4
)
generate_and_test
(
llm
,
sql_lora_files
)
...
...
@@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs
=
16
,
max_loras
=
4
,
tensor_parallel_size
=
4
,
enable_chunked_prefill
=
True
,
)
generate_and_test
(
llm
,
sql_lora_files
)
...
...
@@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras
=
4
,
tensor_parallel_size
=
4
,
fully_sharded_loras
=
True
,
enable_chunked_prefill
=
True
,
)
generate_and_test
(
llm
,
sql_lora_files
)
...
...
tests/lora/test_lora_manager.py
View file @
285178b3
...
...
@@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager
)
from
vllm.platforms
import
current_platform
from
.utils
import
create_peft_lora
EMBEDDING_MODULES
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
...
...
@@ -35,17 +37,6 @@ DEVICES = ([
DEFAULT_DTYPE
=
torch
.
get_default_dtype
()
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Some tests depend on V0 internals. Since both V0 and V1 use the same
LoRAModelManager it is okay to just test V0.
"""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
yield
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
def
test_from_lora_tensors
(
sql_lora_files
,
device
):
tensors
=
load_file
(
...
...
@@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
max_loras
=
2
,
lora_dtype
=
DEFAULT_DTYPE
),
device
=
device
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
# Add up to capacity
...
...
@@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
def
test_lru_cache_worker_adapter_manager
(
llama_2_7b_model_extra_embeddings
,
sql_lora_files
,
device
):
def
test_lru_cache_worker_adapter_manager
(
dist_init
,
dummy_model
,
device
,
tmp_path
):
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
4
,
max_loras
=
4
,
lora_dtype
=
DEFAULT_DTYPE
)
dummy_lora_files
=
f
"
{
tmp_path
}
/lora_adapter"
os
.
makedirs
(
dummy_lora_files
,
exist_ok
=
True
)
create_peft_lora
(
dummy_model
,
save_dir
=
dummy_lora_files
,
target_modules
=
[
"layer1.dense1"
,
"dense2"
],
lora_dtype
=
DEFAULT_DTYPE
,
)
worker_adapter_manager
=
LRUCacheWorkerLoRAManager
(
4
,
2
,
llama_2_7b_model_extra_embeddings
.
unpadded_vocab_size
-
lora_config
.
lora_extra_vocab_size
,
lora_config
,
device
,
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
worker_adapter_manager
.
create_lora_manager
(
llama_2_7b_model_extra_embeddings
)
4
,
2
,
dummy_model
.
unpadded_vocab_size
-
lora_config
.
lora_extra_vocab_size
,
lora_config
,
device
,
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
worker_adapter_manager
.
create_lora_manager
(
dummy_model
)
mapping
=
LoRAMapping
([],
[])
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"2"
,
2
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"2"
,
2
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
2
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
1
]
==
2
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"3"
,
3
,
sql
_lora_files
),
LoRARequest
(
"4"
,
4
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"3"
,
3
,
dummy
_lora_files
),
LoRARequest
(
"4"
,
4
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
2
,
3
,
4
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
3
]
==
4
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"2"
,
2
,
sql
_lora_files
),
LoRARequest
(
"5"
,
5
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"2"
,
2
,
dummy
_lora_files
),
LoRARequest
(
"5"
,
5
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
2
,
4
,
5
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
3
]
==
4
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"1"
,
1
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
2
,
4
,
5
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
3
]
==
4
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"6"
,
6
,
sql
_lora_files
),
LoRARequest
(
"7"
,
7
,
sql
_lora_files
),
LoRARequest
(
"8"
,
8
,
sql
_lora_files
)
LoRARequest
(
"6"
,
6
,
dummy
_lora_files
),
LoRARequest
(
"7"
,
7
,
dummy
_lora_files
),
LoRARequest
(
"8"
,
8
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
6
,
7
,
8
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with
pytest
.
raises
(
RuntimeError
):
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"10"
,
10
,
sql
_lora_files
),
LoRARequest
(
"11"
,
11
,
sql
_lora_files
),
LoRARequest
(
"12"
,
12
,
sql
_lora_files
),
LoRARequest
(
"13"
,
13
,
sql
_lora_files
),
LoRARequest
(
"14"
,
14
,
sql
_lora_files
)
LoRARequest
(
"10"
,
10
,
dummy
_lora_files
),
LoRARequest
(
"11"
,
11
,
dummy
_lora_files
),
LoRARequest
(
"12"
,
12
,
dummy
_lora_files
),
LoRARequest
(
"13"
,
13
,
dummy
_lora_files
),
LoRARequest
(
"14"
,
14
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
device
==
device
...
...
@@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
@
pytest
.
mark
.
parametrize
(
"device"
,
DEVICES
)
def
test_worker_adapter_manager
(
llama_2_7b_model_extra_embeddings
,
sql_lora_files
,
device
):
def
test_worker_adapter_manager
(
dist_init
,
dummy_model_gate_up
,
device
,
tmp_path
):
# Should remove every LoRA not specified in the request.
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
4
,
max_loras
=
4
,
lora_dtype
=
DEFAULT_DTYPE
)
worker_adapter_manager
=
WorkerLoRAManager
(
4
,
2
,
llama_2_7b_model_extra_embeddings
.
unpadded_vocab_size
-
4
,
2
,
dummy_model_gate_up
.
unpadded_vocab_size
-
lora_config
.
lora_extra_vocab_size
,
lora_config
,
device
,
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
worker_adapter_manager
.
create_lora_manager
(
llama_2_7b_model_extra_embeddings
)
worker_adapter_manager
.
create_lora_manager
(
dummy_model_gate_up
)
dummy_lora_files
=
f
"
{
tmp_path
}
/lora_adapter"
os
.
makedirs
(
dummy_lora_files
,
exist_ok
=
True
)
create_peft_lora
(
dummy_model_gate_up
,
save_dir
=
dummy_lora_files
,
target_modules
=
[
"layer1.dense1"
,
"dense2"
],
lora_dtype
=
DEFAULT_DTYPE
,
)
mapping
=
LoRAMapping
([],
[])
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"2"
,
2
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"2"
,
2
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
2
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
1
]
==
2
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"3"
,
3
,
sql
_lora_files
),
LoRARequest
(
"4"
,
4
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"3"
,
3
,
dummy
_lora_files
),
LoRARequest
(
"4"
,
4
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
3
,
4
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
2
]
==
4
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"2"
,
2
,
sql
_lora_files
),
LoRARequest
(
"5"
,
5
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"2"
,
2
,
dummy
_lora_files
),
LoRARequest
(
"5"
,
5
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
,
2
,
5
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
2
]
==
5
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"1"
,
1
,
sql
_lora_files
),
LoRARequest
(
"1"
,
1
,
sql
_lora_files
)
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
),
LoRARequest
(
"1"
,
1
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
1
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
1
...
...
@@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
2
]
is
None
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"6"
,
6
,
sql
_lora_files
),
LoRARequest
(
"7"
,
7
,
sql
_lora_files
),
LoRARequest
(
"8"
,
8
,
sql
_lora_files
)
LoRARequest
(
"6"
,
6
,
dummy
_lora_files
),
LoRARequest
(
"7"
,
7
,
dummy
_lora_files
),
LoRARequest
(
"8"
,
8
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
list_adapters
()
==
{
6
,
7
,
8
}
assert
worker_adapter_manager
.
_adapter_manager
.
lora_index_to_id
[
0
]
==
8
...
...
@@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with
pytest
.
raises
(
RuntimeError
):
worker_adapter_manager
.
set_active_adapters
([
LoRARequest
(
"10"
,
10
,
sql
_lora_files
),
LoRARequest
(
"11"
,
11
,
sql
_lora_files
),
LoRARequest
(
"12"
,
12
,
sql
_lora_files
),
LoRARequest
(
"13"
,
13
,
sql
_lora_files
),
LoRARequest
(
"14"
,
14
,
sql
_lora_files
)
LoRARequest
(
"10"
,
10
,
dummy
_lora_files
),
LoRARequest
(
"11"
,
11
,
dummy
_lora_files
),
LoRARequest
(
"12"
,
12
,
dummy
_lora_files
),
LoRARequest
(
"13"
,
13
,
dummy
_lora_files
),
LoRARequest
(
"14"
,
14
,
dummy
_lora_files
)
],
mapping
)
assert
worker_adapter_manager
.
device
==
device
...
...
tests/lora/test_mixtral.py
View file @
285178b3
...
...
@@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras
=
4
,
distributed_executor_backend
=
"ray"
,
tensor_parallel_size
=
tp_size
,
enable_chunked_prefill
=
True
,
)
expected_lora_output
=
[
...
...
tests/lora/test_worker.py
View file @
285178b3
...
...
@@ -4,17 +4,14 @@
import
os
import
random
import
tempfile
from
typing
import
Union
from
unittest.mock
import
patch
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
from
vllm.lora.models
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.v1.worker.gpu_worker
import
Worker
as
V1Worker
from
vllm.worker.worker
import
Worker
from
vllm.v1.worker.gpu_worker
import
Worker
NUM_LORAS
=
16
...
...
@@ -22,18 +19,11 @@ NUM_LORAS = 16
@
patch
.
dict
(
os
.
environ
,
{
"RANK"
:
"0"
})
def
test_worker_apply_lora
(
sql_lora_files
):
def
set_active_loras
(
worker
:
Union
[
Worker
,
V1Worker
],
lora_requests
:
list
[
LoRARequest
]):
def
set_active_loras
(
worker
:
Worker
,
lora_requests
:
list
[
LoRARequest
]):
lora_mapping
=
LoRAMapping
([],
[])
if
isinstance
(
worker
,
Worker
):
# v0 case
worker
.
model_runner
.
set_active_loras
(
lora_requests
,
lora_mapping
)
else
:
# v1 case
worker
.
model_runner
.
lora_manager
.
set_active_adapters
(
lora_requests
,
lora_mapping
)
worker_cls
=
V1Worker
if
envs
.
VLLM_USE_V1
else
Worker
worker
.
model_runner
.
lora_manager
.
set_active_adapters
(
lora_requests
,
lora_mapping
)
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
...
...
@@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
max_cpu_loras
=
NUM_LORAS
,
max_loras
=
NUM_LORAS
),
)
worker
=
w
orker
_cls
(
worker
=
W
orker
(
vllm_config
=
vllm_config
,
local_rank
=
0
,
rank
=
0
,
...
...
tests/lora/utils.py
View file @
285178b3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
os
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
import
torch
from
safetensors.torch
import
save_file
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
...
...
@@ -340,3 +343,76 @@ def generate_data_for_nslices(
seq_len_tensor
,
indices
,
)
def
create_peft_lora
(
model
:
torch
.
nn
.
Module
,
save_dir
:
str
,
target_modules
:
list
[
str
],
rank
:
int
=
8
,
alpha
:
int
=
16
,
dropout
:
float
=
0.1
,
lora_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
dict
[
str
,
torch
.
Tensor
]:
lora_weights
=
{}
adapter_config
=
{
"peft_type"
:
"LORA"
,
"auto_mapping"
:
None
,
"base_model_name_or_path"
:
"dummy_model"
,
"revision"
:
None
,
"task_type"
:
"CAUSAL_LM"
,
"inference_mode"
:
False
,
"r"
:
rank
,
"lora_alpha"
:
alpha
,
"lora_dropout"
:
dropout
,
"fan_in_fan_out"
:
False
,
"bias"
:
"none"
,
"modules_to_save"
:
None
,
"init_lora_weights"
:
True
,
"layers_to_transform"
:
None
,
"layers_pattern"
:
None
,
"target_modules"
:
target_modules
,
"exclude_modules"
:
None
,
"use_rslora"
:
False
,
"use_dora"
:
False
,
"loftq_config"
:
None
,
}
for
module_name
in
target_modules
:
module
=
model
for
attr
in
module_name
.
split
(
"."
):
module
=
getattr
(
module
,
attr
)
if
hasattr
(
module
,
"input_size"
)
and
hasattr
(
module
,
"output_size"
):
in_features
=
module
.
input_size
out_features
=
module
.
output_size
elif
hasattr
(
module
,
"embedding_dim"
)
and
hasattr
(
module
,
"num_embeddings"
):
# ParallelLMHead
in_features
=
module
.
embedding_dim
out_features
=
module
.
num_embeddings
else
:
raise
ValueError
(
f
"Unable to determine dimensions for module
{
module_name
}
"
)
lora_A
=
torch
.
randn
(
rank
,
in_features
,
dtype
=
lora_dtype
)
torch
.
nn
.
init
.
kaiming_uniform_
(
lora_A
,
a
=
5
**
0.5
)
lora_B
=
torch
.
zeros
(
out_features
,
rank
,
dtype
=
lora_dtype
)
# PEFT style
lora_weights
[
f
"base_model.model.
{
module_name
}
.lora_A.weight"
]
=
lora_A
lora_weights
[
f
"base_model.model.
{
module_name
}
.lora_B.weight"
]
=
lora_B
config_path
=
os
.
path
.
join
(
save_dir
,
"adapter_config.json"
)
with
open
(
config_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
json
.
dump
(
adapter_config
,
f
,
indent
=
2
,
ensure_ascii
=
False
)
weights_path
=
os
.
path
.
join
(
save_dir
,
"adapter_model.safetensors"
)
save_file
(
lora_weights
,
weights_path
)
return
lora_weights
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment