Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db7c8ca9
Unverified
Commit
db7c8ca9
authored
Mar 18, 2025
by
Jee Jee Li
Committed by
GitHub
Mar 18, 2025
Browse files
[Misc] Embedding model support LoRA (#14935)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
f863ffc9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
2 deletions
+30
-2
vllm/lora/models.py
vllm/lora/models.py
+30
-2
No files found.
vllm/lora/models.py
View file @
db7c8ca9
...
@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
...
@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.interfaces
import
is_pooling_model
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.utils
import
PPMissingLayer
,
WeightsMapper
from
vllm.model_executor.models.utils
import
PPMissingLayer
,
WeightsMapper
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
...
@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
"""Get LoRA for a given module by name"""
"""Get LoRA for a given module by name"""
return
self
.
loras
.
get
(
module_name
,
None
)
return
self
.
loras
.
get
(
module_name
,
None
)
def
check_lora_name
(
self
,
lora_name
:
str
)
->
bool
:
return
lora_name
in
self
.
loras
# (yard1): TODO see if we can derive target_embedding_padding automatically
# (yard1): TODO see if we can derive target_embedding_padding automatically
@
classmethod
@
classmethod
def
from_lora_tensors
(
def
from_lora_tensors
(
...
@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
# Used for long context lora.
# Used for long context lora.
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
super
().
__init__
(
model
)
super
().
__init__
(
model
)
self
.
supported_lora_modules
=
get_supported_lora_modules
(
self
.
model
)
self
.
supported_lora_modules
=
get_supported_lora_modules
(
self
.
model
)
assert
self
.
supported_lora_modules
,
"No supported LoRA modules found in"
assert
self
.
supported_lora_modules
,
"No supported LoRA modules found in"
f
"
{
self
.
model
.
__class__
.
__name__
}
."
f
"
{
self
.
model
.
__class__
.
__name__
}
."
...
@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
# In case the model only supports LoRA for
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
# text modules (e.g. ChatGLM)
and
hasattr
(
self
.
model
,
"get_mm_mapping"
))
and
hasattr
(
self
.
model
,
"get_mm_mapping"
))
self
.
is_pooling_model
=
is_pooling_model
(
self
.
model
)
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
modules
:
Dict
[
str
,
BaseLayerWithLoRA
]
=
{}
self
.
modules
:
Dict
[
str
,
BaseLayerWithLoRA
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
# Dict instead of a Set for compatibility with LRUCache.
...
@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
lora_model
.
id
,
index
)
lora_model
.
id
,
index
)
self
.
lora_index_to_id
[
index
]
=
lora_model
.
id
self
.
lora_index_to_id
[
index
]
=
lora_model
.
id
for
module_name
,
module
in
self
.
modules
.
items
():
for
module_name
,
module
in
self
.
modules
.
items
():
module_lora
=
lora_model
.
get_lora
(
module_name
)
module_lora
=
self
.
_get_lora_layer_weights
(
lora_model
,
module_name
)
if
module_lora
:
if
module_lora
:
module_lora
.
optimize
()
module_lora
.
optimize
()
# Bias is not explicitly enabled with the flag enable_lora_bias.
# Bias is not explicitly enabled with the flag enable_lora_bias.
...
@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
replaced_module
:
Set
[
str
]
=
set
()
replaced_module
:
Set
[
str
]
=
set
()
has_replacement
=
False
has_replacement
=
False
for
r
in
new_module_names
:
for
r
in
new_module_names
:
lora
=
lora_model
.
get_lora
(
r
)
lora
=
self
.
_get_lora_layer_weights
(
lora_model
,
r
)
replacement_loras
.
append
(
lora
)
replacement_loras
.
append
(
lora
)
if
lora
:
if
lora
:
has_replacement
=
True
has_replacement
=
True
...
@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
if
replacement_loras
[
i
]:
if
replacement_loras
[
i
]:
continue
continue
replacement_loras
[
i
]
=
None
replacement_loras
[
i
]
=
None
# HACK Temporary solution for the pool model.
if
self
.
is_pooling_model
and
not
lora_model
.
check_lora_name
(
module_name
):
replaced_module_name
=
module_name
.
replace
(
"model."
,
""
)
if
lora_model
.
check_lora_name
(
module_name
):
module_name
=
replaced_module_name
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
replacement_loras
)
replacement_loras
)
# Remove the modules that have been replaced.
# Remove the modules that have been replaced.
for
module
in
replaced_module
:
for
module
in
replaced_module
:
lora_model
.
loras
.
pop
(
module
,
None
)
lora_model
.
loras
.
pop
(
module
,
None
)
def
_get_lora_layer_weights
(
self
,
lora_model
:
LoRAModel
,
module_name
:
str
)
->
Optional
[
LoRALayerWeights
]:
org_module_name
=
module_name
if
self
.
is_pooling_model
and
not
lora_model
.
check_lora_name
(
module_name
):
# If it's a pool model, and the layer name is not found,
# remove the prefix 'model.' and search again.
module_name
=
module_name
.
replace
(
"model."
,
""
)
if
lora_model
.
check_lora_name
(
module_name
):
org_module_name
=
module_name
logger
.
info_once
(
"For the pool model, successfully loaded the LoRA weights "
"after removing the prefix 'model.'."
)
return
lora_model
.
get_lora
(
org_module_name
)
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
return
deactivate_adapter
(
adapter_id
,
self
.
_active_adapters
,
return
deactivate_adapter
(
adapter_id
,
self
.
_active_adapters
,
self
.
_deactivate_adapter
)
self
.
_deactivate_adapter
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment