Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d88506dd
Unverified
Commit
d88506dd
authored
Feb 05, 2025
by
Sumit Vij
Committed by
GitHub
Feb 05, 2025
Browse files
[Model] LoRA Support for Ultravox model (#11253)
parent
9cdea30b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
160 additions
and
7 deletions
+160
-7
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+1
-1
tests/conftest.py
tests/conftest.py
+12
-4
tests/lora/test_ultravox.py
tests/lora/test_ultravox.py
+121
-0
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+26
-2
No files found.
docs/source/models/supported_models.md
View file @
d88506dd
...
@@ -857,7 +857,7 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -857,7 +857,7 @@ See [this page](#generative-models) for more information on how to use generativ
*
Ultravox
*
Ultravox
*
T + A
<sup>
E+
</sup>
*
T + A
<sup>
E+
</sup>
*
`fixie-ai/ultravox-v0_3`
*
`fixie-ai/ultravox-v0_3`
*
*
✅︎
*
✅︎
*
✅︎
*
✅︎
*
✅︎
:::
:::
...
...
tests/conftest.py
View file @
d88506dd
...
@@ -737,6 +737,7 @@ class VllmRunner:
...
@@ -737,6 +737,7 @@ class VllmRunner:
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
inputs
=
self
.
get_inputs
(
prompts
,
inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
images
=
images
,
...
@@ -744,7 +745,8 @@ class VllmRunner:
...
@@ -744,7 +745,8 @@ class VllmRunner:
audios
=
audios
)
audios
=
audios
)
req_outputs
=
self
.
model
.
generate
(
inputs
,
req_outputs
=
self
.
model
.
generate
(
inputs
,
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
,
**
kwargs
)
outputs
:
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]
=
[]
outputs
:
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]
=
[]
for
req_output
in
req_outputs
:
for
req_output
in
req_outputs
:
...
@@ -782,6 +784,7 @@ class VllmRunner:
...
@@ -782,6 +784,7 @@ class VllmRunner:
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
**
kwargs
:
Any
,
)
->
Union
[
List
[
TokensTextLogprobs
],
)
->
Union
[
List
[
TokensTextLogprobs
],
List
[
TokensTextLogprobsPromptLogprobs
]]:
List
[
TokensTextLogprobsPromptLogprobs
]]:
inputs
=
self
.
get_inputs
(
prompts
,
inputs
=
self
.
get_inputs
(
prompts
,
...
@@ -790,7 +793,8 @@ class VllmRunner:
...
@@ -790,7 +793,8 @@ class VllmRunner:
audios
=
audios
)
audios
=
audios
)
req_outputs
=
self
.
model
.
generate
(
inputs
,
req_outputs
=
self
.
model
.
generate
(
inputs
,
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
,
**
kwargs
)
toks_str_logsprobs_prompt_logprobs
=
(
toks_str_logsprobs_prompt_logprobs
=
(
self
.
_final_steps_generate_w_logprobs
(
req_outputs
))
self
.
_final_steps_generate_w_logprobs
(
req_outputs
))
...
@@ -826,13 +830,15 @@ class VllmRunner:
...
@@ -826,13 +830,15 @@ class VllmRunner:
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
outputs
=
self
.
generate
(
prompts
,
outputs
=
self
.
generate
(
prompts
,
greedy_params
,
greedy_params
,
images
=
images
,
images
=
images
,
videos
=
videos
,
videos
=
videos
,
audios
=
audios
)
audios
=
audios
,
**
kwargs
)
return
[(
output_ids
[
0
],
output_str
[
0
])
return
[(
output_ids
[
0
],
output_str
[
0
])
for
output_ids
,
output_str
in
outputs
]
for
output_ids
,
output_str
in
outputs
]
...
@@ -847,6 +853,7 @@ class VllmRunner:
...
@@ -847,6 +853,7 @@ class VllmRunner:
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop
:
Optional
[
List
[
str
]]
=
None
,
**
kwargs
:
Any
,
)
->
Union
[
List
[
TokensTextLogprobs
],
)
->
Union
[
List
[
TokensTextLogprobs
],
List
[
TokensTextLogprobsPromptLogprobs
]]:
List
[
TokensTextLogprobsPromptLogprobs
]]:
greedy_logprobs_params
=
SamplingParams
(
greedy_logprobs_params
=
SamplingParams
(
...
@@ -861,7 +868,8 @@ class VllmRunner:
...
@@ -861,7 +868,8 @@ class VllmRunner:
greedy_logprobs_params
,
greedy_logprobs_params
,
images
=
images
,
images
=
images
,
audios
=
audios
,
audios
=
audios
,
videos
=
videos
)
videos
=
videos
,
**
kwargs
)
def
generate_encoder_decoder_greedy_logprobs
(
def
generate_encoder_decoder_greedy_logprobs
(
self
,
self
,
...
...
tests/lora/test_ultravox.py
0 → 100644
View file @
d88506dd
import
shutil
from
os
import
path
from
tempfile
import
TemporaryDirectory
from
typing
import
List
,
Tuple
import
torch
from
huggingface_hub
import
snapshot_download
from
safetensors.torch
import
load_file
,
save_file
from
transformers
import
AutoTokenizer
from
vllm.lora.request
import
LoRARequest
from
..models.utils
import
check_outputs_equal
ULTRAVOX_MODEL_NAME
=
"fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME
=
"meta-llama/Llama-3.1-8B-Instruct"
VLLM_PLACEHOLDER
=
"<|reserved_special_token_0|>"
PROMPT
=
"Tell me about a Fool's mate move in 20 words. Provide the moves!"
def llama3_1_8b_chess_lora_path():
    """Fetch the Llama 3.1 8B chess LoRA adapter from the Hugging Face Hub.

    Returns whatever ``snapshot_download`` returns for the adapter repo;
    callers in this file treat it as a local directory path.
    """
    adapter_repo = "mkopecki/chess-lora-adapter-llama-3.1-8b"
    return snapshot_download(repo_id=adapter_repo)
# A Llama LoRA adapter can't be used as-is: its module names must be
# transformed because Ultravox nests the language model.
def transform_module_names_for_ultravox(state_dict):
    """Return a copy of ``state_dict`` with keys re-rooted for Ultravox.

    Every occurrence of ``"base_model.model"`` in a key is replaced with
    ``"base_model.model.language_model"``; keys that do not contain the
    marker are kept unchanged. Values are passed through as-is (not copied)
    and the input mapping is not modified.

    Args:
        state_dict: Mapping of LoRA tensor names to tensors.

    Returns:
        A new ``dict`` with the renamed keys, in the input's iteration order.
    """
    return {
        key.replace("base_model.model", "base_model.model.language_model"):
        value
        for key, value in state_dict.items()
    }
def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
    """Write an Ultravox-compatible copy of a Llama LoRA adapter.

    Loads ``adapter_model.safetensors`` from ``source_repo``, renames its
    keys with ``transform_module_names_for_ultravox``, saves the result under
    the same file name in ``target_path``, then copies
    ``adapter_config.json`` across unchanged.

    Args:
        source_repo: Directory holding the original adapter files.
        target_path: Existing directory that receives the converted adapter.

    Returns:
        ``target_path``, for convenient chaining.
    """
    weights_name = "adapter_model.safetensors"
    original_weights = load_file(path.join(source_repo, weights_name))
    renamed_weights = transform_module_names_for_ultravox(original_weights)
    save_file(renamed_weights, path.join(target_path, weights_name))

    # The adapter config needs no changes; copy it verbatim.
    config_name = "adapter_config.json"
    config_src = path.join(source_repo, config_name)
    config_dst = path.join(target_path, config_name)
    shutil.copyfile(config_src, config_dst)

    return target_path
def _get_prompt(audio_count, question, placeholder, model_name) -> str:
    """Build a single-turn chat prompt string for ``model_name``.

    The user message is ``placeholder`` followed by a newline, repeated
    ``audio_count`` times, then ``question``. It is rendered with the
    tokenizer's chat template without tokenizing and with the generation
    prompt appended.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # One "<placeholder>\n" per audio clip; empty when audio_count is 0.
    audio_prefix = f"{placeholder}\n" * audio_count
    messages = [{
        'role': 'user',
        'content': f"{audio_prefix}{question}"
    }]

    return tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True)
def
test_ultravox_lora
(
vllm_runner
):
"""
Check that Ultravox with a remapped Llama LoRA matches Llama with the
original LoRA.

Both models greedily generate up to 256 tokens for the same text-only
prompt (audio count 0), and the two outputs are compared with
check_outputs_equal.

TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
"""
# Workaround to prevent device mismatch in Whisper.
# Can be removed when it is fixed upstream in transformers
# https://github.com/huggingface/transformers/pull/35866
# NOTE(review): this sets the default device globally and it is not restored
# here - confirm it does not leak into other tests.
torch
.
set_default_device
(
"cpu"
)
llama3_1_8b_chess_lora
=
llama3_1_8b_chess_lora_path
()
# The remapped adapter is written to a temporary directory.
with
TemporaryDirectory
()
as
temp_ultravox_lora_dir
:
llama3_1_8b_ultravox_chess_lora
=
mk_llama3_1_8b_ultravox_chess_lora
(
llama3_1_8b_chess_lora
,
temp_ultravox_lora_dir
)
# Run Ultravox with the remapped adapter.
with
vllm_runner
(
ULTRAVOX_MODEL_NAME
,
enforce_eager
=
True
,
max_num_seqs
=
2
,
enable_lora
=
True
,
max_loras
=
1
,
max_lora_rank
=
128
,
dtype
=
"bfloat16"
,
max_model_len
=
1024
,
)
as
vllm_model
:
ultravox_outputs
:
List
[
Tuple
[
List
[
int
],
str
]]
=
vllm_model
.
generate_greedy
(
[
_get_prompt
(
0
,
PROMPT
,
VLLM_PLACEHOLDER
,
ULTRAVOX_MODEL_NAME
)
],
256
,
lora_request
=
LoRARequest
(
str
(
1
),
1
,
llama3_1_8b_ultravox_chess_lora
),
)
# Run Llama with the original (untransformed) LoRA adapter, using the same
# runner settings, to compare outputs with the Ultravox run above.
with
vllm_runner
(
LLMA_MODEL_NAME
,
enforce_eager
=
True
,
max_num_seqs
=
2
,
enable_lora
=
True
,
max_loras
=
1
,
max_lora_rank
=
128
,
dtype
=
"bfloat16"
,
max_model_len
=
1024
,
)
as
vllm_model
:
llama_outputs
:
List
[
Tuple
[
List
[
int
],
str
]]
=
(
vllm_model
.
generate_greedy
(
[
_get_prompt
(
0
,
PROMPT
,
VLLM_PLACEHOLDER
,
LLMA_MODEL_NAME
)],
256
,
lora_request
=
LoRARequest
(
str
(
1
),
1
,
llama3_1_8b_chess_lora
),
))
# The two runs are expected to produce identical outputs.
check_outputs_equal
(
outputs_0_lst
=
ultravox_outputs
,
outputs_1_lst
=
llama_outputs
,
name_0
=
"ultravox"
,
name_1
=
"llama"
,
)
vllm/model_executor/models/ultravox.py
View file @
d88506dd
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargs
,
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargs
,
...
@@ -33,7 +34,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...
@@ -33,7 +34,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
,
merge_multimodal_embeddings
,
...
@@ -343,7 +344,20 @@ class ModifiedWhisperEncoder(WhisperEncoder):
...
@@ -343,7 +344,20 @@ class ModifiedWhisperEncoder(WhisperEncoder):
UltravoxMultiModalProcessor
,
UltravoxMultiModalProcessor
,
info
=
UltravoxProcessingInfo
,
info
=
UltravoxProcessingInfo
,
dummy_inputs
=
UltravoxDummyInputsBuilder
)
dummy_inputs
=
UltravoxDummyInputsBuilder
)
class
UltravoxModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
UltravoxModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
# LoRA specific attributes
# TODO : Add LoRA to the audio tower and projector.
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"audio_tower.model.encoder."
:
"audio_tower."
})
orig_to_new_prefix
=
{
"audio_tower.model.encoder."
:
"audio_tower."
})
...
@@ -391,6 +405,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -391,6 +405,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
return
get_sampler
()
return
get_sampler
()
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefixes of this multimodal model's components.

Returns a MultiModelKeys built from the "language_model.",
"multi_modal_projector." (connector) and "audio_tower." (tower model)
prefixes.
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"language_model."
,
connector
=
"multi_modal_projector."
,
tower_model
=
"audio_tower."
,
)
def
_audio_features_to_embeddings
(
def
_audio_features_to_embeddings
(
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
audio_input
=
input_features
.
to
(
self
.
audio_tower
.
dtype
)
audio_input
=
input_features
.
to
(
self
.
audio_tower
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment