Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82cabf53
Unverified
Commit
82cabf53
authored
Feb 13, 2025
by
Jee Jee Li
Committed by
GitHub
Feb 12, 2025
Browse files
[Misc] Delete unused LoRA modules (#13151)
parent
314cfade
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
8 deletions
+20
-8
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+12
-6
vllm/lora/models.py
vllm/lora/models.py
+7
-1
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+1
-1
No files found.
tests/lora/test_lora_manager.py
View file @
82cabf53
...
@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
...
@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
assert
isinstance
(
model
.
get_submodule
(
"gate_up_proj"
),
assert
isinstance
(
model
.
get_submodule
(
"gate_up_proj"
),
MergedColumnParallelLinearWithLoRA
)
MergedColumnParallelLinearWithLoRA
)
# Verify packed lora is correct
model_lora_clone
=
model_lora
.
clone
(
1
)
model_lora_clone1
=
model_lora1
.
clone
(
1
)
assert
manager
.
add_adapter
(
model_lora
)
assert
manager
.
add_adapter
(
model_lora
)
assert
manager
.
add_adapter
(
model_lora1
)
assert
manager
.
add_adapter
(
model_lora1
)
assert
model_lora
.
get_lora
(
"gate_proj"
)
is
None
assert
model_lora
.
get_lora
(
"up_proj"
)
is
None
assert
model_lora1
.
get_lora
(
"up_proj"
)
is
None
packed_lora
=
model_lora
.
get_lora
(
"gate_up_proj"
)
packed_lora
=
model_lora
.
get_lora
(
"gate_up_proj"
)
assert
packed_lora
and
isinstance
(
packed_lora
,
PackedLoRALayerWeights
)
assert
packed_lora
and
isinstance
(
packed_lora
,
PackedLoRALayerWeights
)
torch
.
testing
.
assert_close
(
packed_lora
.
lora_a
[
0
],
torch
.
testing
.
assert_close
(
packed_lora
.
lora_a
[
0
],
model_lora
.
get_lora
(
"gate_proj"
).
lora_a
)
model_lora
_clone
.
get_lora
(
"gate_proj"
).
lora_a
)
torch
.
testing
.
assert_close
(
packed_lora
.
lora_b
[
0
],
torch
.
testing
.
assert_close
(
packed_lora
.
lora_b
[
0
],
model_lora
.
get_lora
(
"gate_proj"
).
lora_b
)
model_lora
_clone
.
get_lora
(
"gate_proj"
).
lora_b
)
torch
.
testing
.
assert_close
(
packed_lora
.
lora_a
[
1
],
torch
.
testing
.
assert_close
(
packed_lora
.
lora_a
[
1
],
model_lora
.
get_lora
(
"up_proj"
).
lora_a
)
model_lora
_clone
.
get_lora
(
"up_proj"
).
lora_a
)
torch
.
testing
.
assert_close
(
packed_lora
.
lora_b
[
1
],
torch
.
testing
.
assert_close
(
packed_lora
.
lora_b
[
1
],
model_lora
.
get_lora
(
"up_proj"
).
lora_b
)
model_lora
_clone
.
get_lora
(
"up_proj"
).
lora_b
)
packed_lora1
=
model_lora1
.
get_lora
(
"gate_up_proj"
)
packed_lora1
=
model_lora1
.
get_lora
(
"gate_up_proj"
)
assert
packed_lora1
and
isinstance
(
packed_lora1
,
PackedLoRALayerWeights
)
assert
packed_lora1
and
isinstance
(
packed_lora1
,
PackedLoRALayerWeights
)
...
@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
...
@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
assert
packed_lora1
.
lora_a
[
0
]
is
None
assert
packed_lora1
.
lora_a
[
0
]
is
None
assert
packed_lora1
.
lora_b
[
0
]
is
None
assert
packed_lora1
.
lora_b
[
0
]
is
None
torch
.
testing
.
assert_close
(
packed_lora1
.
lora_a
[
1
],
torch
.
testing
.
assert_close
(
packed_lora1
.
lora_a
[
1
],
model_lora1
.
get_lora
(
"up_proj"
).
lora_a
)
model_lora
_clone
1
.
get_lora
(
"up_proj"
).
lora_a
)
torch
.
testing
.
assert_close
(
packed_lora1
.
lora_b
[
1
],
torch
.
testing
.
assert_close
(
packed_lora1
.
lora_b
[
1
],
model_lora1
.
get_lora
(
"up_proj"
).
lora_b
)
model_lora
_clone
1
.
get_lora
(
"up_proj"
).
lora_b
)
vllm/lora/models.py
View file @
82cabf53
...
@@ -5,7 +5,8 @@ import math
...
@@ -5,7 +5,8 @@ import math
import
os
import
os
import
re
import
re
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Type
,
Union
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Type
,
Union
)
import
safetensors.torch
import
safetensors.torch
import
torch
import
torch
...
@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
def
_create_merged_loras_inplace
(
self
,
lora_model
:
LoRAModel
)
->
None
:
def
_create_merged_loras_inplace
(
self
,
lora_model
:
LoRAModel
)
->
None
:
for
module_name
,
new_module_names
in
self
.
packed_modules
.
items
():
for
module_name
,
new_module_names
in
self
.
packed_modules
.
items
():
replacement_loras
:
List
[
Optional
[
LoRALayerWeights
]]
=
[]
replacement_loras
:
List
[
Optional
[
LoRALayerWeights
]]
=
[]
replaced_module
:
Set
[
str
]
=
set
()
has_replacement
=
False
has_replacement
=
False
for
r
in
new_module_names
:
for
r
in
new_module_names
:
lora
=
lora_model
.
get_lora
(
r
)
lora
=
lora_model
.
get_lora
(
r
)
replacement_loras
.
append
(
lora
)
replacement_loras
.
append
(
lora
)
if
lora
:
if
lora
:
has_replacement
=
True
has_replacement
=
True
replaced_module
.
add
(
r
)
if
not
has_replacement
:
if
not
has_replacement
:
continue
continue
for
i
in
range
(
len
(
replacement_loras
)):
for
i
in
range
(
len
(
replacement_loras
)):
...
@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
replacement_loras
[
i
]
=
None
replacement_loras
[
i
]
=
None
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
replacement_loras
)
replacement_loras
)
# Remove the modules that have been replaced.
for
module
in
replaced_module
:
lora_model
.
loras
.
pop
(
module
,
None
)
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
return
deactivate_adapter
(
adapter_id
,
self
.
_active_adapters
,
return
deactivate_adapter
(
adapter_id
,
self
.
_active_adapters
,
...
...
vllm/lora/punica_wrapper/punica_base.py
View file @
82cabf53
...
@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
...
@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
# 5 is the number of indic
i
es tensors.
# 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
# embeddings_indices,long_lora_indices
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
5
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
5
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment