Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11d8a091
Unverified
Commit
11d8a091
authored
Jan 01, 2025
by
Jee Jee Li
Committed by
GitHub
Jan 01, 2025
Browse files
[Misc] Optimize Qwen2-VL LoRA test (#11663)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
365801fe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
4 deletions
+21
-4
tests/lora/test_qwen2vl.py
tests/lora/test_qwen2vl.py
+2
-3
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+19
-1
No files found.
tests/lora/test_qwen2vl.py
View file @
11d8a091
...
@@ -7,7 +7,7 @@ from vllm.assets.image import ImageAsset
...
@@ -7,7 +7,7 @@ from vllm.assets.image import ImageAsset
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
MODEL_PATH
=
"Qwen/Qwen2-VL-
7
B-Instruct"
MODEL_PATH
=
"Qwen/Qwen2-VL-
2
B-Instruct"
PROMPT_TEMPLATE
=
(
PROMPT_TEMPLATE
=
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>"
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>"
...
@@ -49,10 +49,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
...
@@ -49,10 +49,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
# Print the outputs.
# Print the outputs.
generated_texts
:
List
[
str
]
=
[]
generated_texts
:
List
[
str
]
=
[]
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
.
strip
()
generated_text
=
output
.
outputs
[
0
].
text
.
strip
()
generated_texts
.
append
(
generated_text
)
generated_texts
.
append
(
generated_text
)
print
(
f
"
Prompt:
{
prompt
!
r
}
,
Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
return
generated_texts
return
generated_texts
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
11d8a091
...
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
...
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig
)
GPTQMarlinConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalFieldConfig
,
MultiModalKwargs
,
MultiModalFieldConfig
,
MultiModalKwargs
,
...
@@ -926,15 +927,23 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -926,15 +927,23 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
}
}
# LoRA specific attributes
# LoRA specific attributes
# TODO Support LoRA for the visual encoder in the future.
supported_lora_modules
=
[
supported_lora_modules
=
[
"qkv_proj"
,
"qkv_proj"
,
"o_proj"
,
"o_proj"
,
"gate_up_proj"
,
"gate_up_proj"
,
"down_proj"
,
"down_proj"
,
# vision tower
"qkv"
,
"attn.proj"
,
# Distinguish patch_embed.proj
"fc1"
,
"fc2"
,
# projector
"mlp.0"
,
"mlp.2"
]
]
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
# To ensure correct weight loading and mapping.
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"lm_head."
:
"language_model.lm_head."
,
"lm_head."
:
"language_model.lm_head."
,
...
@@ -1231,3 +1240,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1231,3 +1240,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"language_model"
,
connector
=
"visual."
,
tower_model
=
"visual.merger."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment