Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2be8ec6e
Unverified
Commit
2be8ec6e
authored
Sep 03, 2024
by
Peter Salas
Committed by
GitHub
Sep 04, 2024
Browse files
[Model] Add Ultravox support for multiple audio chunks (#7963)
parent
e16fa99a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
198 additions
and
115 deletions
+198
-115
examples/offline_inference_audio_language.py
examples/offline_inference_audio_language.py
+34
-24
tests/models/test_ultravox.py
tests/models/test_ultravox.py
+77
-26
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+87
-65
No files found.
examples/offline_inference_audio_language.py
View file @
2be8ec6e
...
@@ -11,25 +11,33 @@ from vllm import LLM, SamplingParams
...
@@ -11,25 +11,33 @@ from vllm import LLM, SamplingParams
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.audio
import
AudioAsset
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
# Input audio and question
audio_assets
=
[
AudioAsset
(
"mary_had_lamb"
),
AudioAsset
(
"winning_call"
)]
audio_and_sample_rate
=
AudioAsset
(
"mary_had_lamb"
).
audio_and_sample_rate
question_per_audio_count
=
[
question
=
"What is recited in the audio?"
"What is recited in the audio?"
,
"What sport and what nursery rhyme are referenced?"
]
# Ultravox 0.3
# Ultravox 0.3
def
run_ultravox
(
question
):
def
run_ultravox
(
question
,
audio_count
):
model_name
=
"fixie-ai/ultravox-v0_3"
model_name
=
"fixie-ai/ultravox-v0_3"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
messages
=
[{
messages
=
[{
'role'
:
'user'
,
'role'
:
'content'
:
f
"<|reserved_special_token_0|>
\n
{
question
}
"
'user'
,
'content'
:
"<|reserved_special_token_0|>
\n
"
*
audio_count
+
question
}]
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
tokenize
=
False
,
add_generation_prompt
=
True
)
add_generation_prompt
=
True
)
llm
=
LLM
(
model
=
model_name
)
llm
=
LLM
(
model
=
model_name
,
enforce_eager
=
True
,
enable_chunked_prefill
=
False
,
max_model_len
=
8192
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
})
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
@@ -44,7 +52,9 @@ def main(args):
...
@@ -44,7 +52,9 @@ def main(args):
if
model
not
in
model_example_map
:
if
model
not
in
model_example_map
:
raise
ValueError
(
f
"Model type
{
model
}
is not supported."
)
raise
ValueError
(
f
"Model type
{
model
}
is not supported."
)
llm
,
prompt
,
stop_token_ids
=
model_example_map
[
model
](
question
)
audio_count
=
args
.
num_audios
llm
,
prompt
,
stop_token_ids
=
model_example_map
[
model
](
question_per_audio_count
[
audio_count
-
1
],
audio_count
)
# We set temperature to 0.2 so that outputs can be different
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
# even when all prompts are identical when running batch inference.
...
@@ -53,23 +63,18 @@ def main(args):
...
@@ -53,23 +63,18 @@ def main(args):
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
)
assert
args
.
num_prompts
>
0
assert
args
.
num_prompts
>
0
i
f
args
.
num_prompts
==
1
:
i
nputs
=
{
# Single inference
"prompt"
:
prompt
,
inputs
=
{
"multi_modal_data"
:
{
"
prompt"
:
prompt
,
"
audio"
:
[
"multi_modal_data"
:
{
asset
.
audio_and_sample_rate
"audio"
:
audio_and_sample_rate
for
asset
in
audio_assets
[:
audio_count
]
},
]
}
}
,
}
else
:
if
args
.
num_prompts
>
1
:
# Batch inference
# Batch inference
inputs
=
[{
inputs
=
[
inputs
]
*
args
.
num_prompts
"prompt"
:
prompt
,
"multi_modal_data"
:
{
"audio"
:
audio_and_sample_rate
},
}
for
_
in
range
(
args
.
num_prompts
)]
outputs
=
llm
.
generate
(
inputs
,
sampling_params
=
sampling_params
)
outputs
=
llm
.
generate
(
inputs
,
sampling_params
=
sampling_params
)
...
@@ -92,6 +97,11 @@ if __name__ == "__main__":
...
@@ -92,6 +97,11 @@ if __name__ == "__main__":
type
=
int
,
type
=
int
,
default
=
1
,
default
=
1
,
help
=
'Number of prompts to run.'
)
help
=
'Number of prompts to run.'
)
parser
.
add_argument
(
"--num-audios"
,
type
=
int
,
default
=
1
,
choices
=
[
1
,
2
],
help
=
"Number of audio items per prompt."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
tests/models/test_ultravox.py
View file @
2be8ec6e
...
@@ -16,37 +16,32 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
...
@@ -16,37 +16,32 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple
=
Tuple
[
np
.
ndarray
,
int
]
AudioTuple
=
Tuple
[
np
.
ndarray
,
int
]
VLLM_PLACEHOLDER
=
"<|reserved_special_token_0|>"
HF_PLACEHOLDER
=
"<|audio|>"
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
audio_a
nd_sample_rate
():
def
audio_a
ssets
():
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.audio
import
AudioAsset
return
AudioAsset
(
"mary_had_lamb"
)
.
a
udio
_and_sample_rate
return
[
AudioAsset
(
"mary_had_lamb"
)
,
A
udio
Asset
(
"winning_call"
)]
@
pytest
.
fixture
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
(
"mary_had_lamb"
,
"winning_call"
))
def
prompts_and_audios
(
audio_and_sample_rate
):
def
audio
(
request
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
from
vllm.assets.audio
import
AudioAsset
return
AudioAsset
(
request
.
param
)
vllm_placeholder
=
"<|reserved_special_token_0|>"
hf_placeholder
=
"<|audio|>"
question
=
"What's in the audio?"
def
_get_prompt
(
audio_count
,
question
,
placeholder
):
vllm_prompt
=
tokenizer
.
apply_chat_template
(
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
[{
placeholder
=
f
"
{
placeholder
}
\n
"
*
audio_count
'role'
:
'user'
,
'content'
:
f
"
{
vllm_placeholder
}
\n
{
question
}
"
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
hf_prompt
=
tokenizer
.
apply_chat_template
(
[{
'role'
:
'user'
,
'content'
:
f
"
{
hf_placeholder
}
\n
{
question
}
"
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
return
[(
vllm_prompt
,
hf_prompt
,
audio_and_sample_rate
)]
return
tokenizer
.
apply_chat_template
([{
'role'
:
'user'
,
'content'
:
f
"
{
placeholder
}{
question
}
"
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
...
@@ -134,15 +129,71 @@ def run_test(
...
@@ -134,15 +129,71 @@ def run_test(
)
)
def
run_multi_audio_test
(
vllm_runner
:
Type
[
VllmRunner
],
prompts_and_audios
:
List
[
Tuple
[
str
,
List
[
AudioTuple
]]],
model
:
str
,
*
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"audio"
:
max
((
len
(
audio
)
for
_
,
audio
in
prompts_and_audios
))
})
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
[
prompt
for
prompt
,
_
in
prompts_and_audios
],
max_tokens
,
num_logprobs
=
num_logprobs
,
audios
=
[
audios
for
_
,
audios
in
prompts_and_audios
])
# The HuggingFace model doesn't support multiple audios yet, so
# just assert that some tokens were generated.
assert
all
(
tokens
for
tokens
,
*
_
in
vllm_outputs
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
prompts_and_audios
,
dtype
:
str
,
def
test_models
(
hf_runner
,
vllm_runner
,
audio
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
num_logprobs
:
int
)
->
None
:
vllm_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
VLLM_PLACEHOLDER
)
hf_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
HF_PLACEHOLDER
)
run_test
(
run_test
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
prompts_and_audios
,
[(
vllm_prompt
,
hf_prompt
,
audio
.
audio_and_sample_rate
)],
MODEL_NAME
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models_with_multiple_audios
(
vllm_runner
,
audio_assets
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
vllm_prompt
=
_get_prompt
(
len
(
audio_assets
),
"Describe each of the audios above."
,
VLLM_PLACEHOLDER
)
run_multi_audio_test
(
vllm_runner
,
[(
vllm_prompt
,
[
audio
.
audio_and_sample_rate
for
audio
in
audio_assets
])],
MODEL_NAME
,
MODEL_NAME
,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
...
...
vllm/model_executor/models/ultravox.py
View file @
2be8ec6e
...
@@ -29,12 +29,12 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -29,12 +29,12 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
(
filter_weights
,
from
vllm.model_executor.models.utils
import
(
filter_weights
,
flatten_bn
,
init_vllm_registered_model
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
,
NestedTensors
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
...
@@ -48,13 +48,14 @@ logger = init_logger(__name__)
...
@@ -48,13 +48,14 @@ logger = init_logger(__name__)
class
UltravoxAudioFeatureInputs
(
TypedDict
):
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
type
:
Literal
[
"audio_features"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
Nested
Tensor
s
"""Shape: `(batch_size
*
num_audios, 80, M)"""
"""Shape: `(batch_size
,
num_audios, 80, M)"""
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"audio_embeds"
]
type
:
Literal
[
"audio_embeds"
]
data
:
torch
.
Tensor
data
:
NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
UltravoxAudioInputs
=
Union
[
UltravoxAudioFeatureInputs
,
UltravoxAudioInputs
=
Union
[
UltravoxAudioFeatureInputs
,
...
@@ -85,24 +86,33 @@ def dummy_data_for_ultravox(
...
@@ -85,24 +86,33 @@ def dummy_data_for_ultravox(
audio_count
=
mm_counts
[
"audio"
]
audio_count
=
mm_counts
[
"audio"
]
audio_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
audio_placeholder
=
array
(
_AUDIO_PLACEHOLDER_TOKEN
VLLM_TOKEN_ID_ARRAY_TYPE
,
])
*
get_ultravox_max_audio_tokens
(
ctx
)
*
audio_count
[
_AUDIO_PLACEHOLDER_TOKEN
])
*
get_ultravox_max_audio_tokens
(
ctx
)
# Add a separator between each chunk.
audio_token_ids
=
(
audio_placeholder
+
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]))
*
audio_count
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
mm_dict
=
{
mm_dict
=
{
"audio"
:
[
audio_and_sr
]
*
audio_count
}
"audio"
:
audio_and_sr
if
audio_count
==
1
else
[
audio_and_sr
]
*
audio_count
}
return
(
SequenceData
(
audio_token_ids
+
other_token_ids
),
mm_dict
)
return
(
SequenceData
(
audio_token_ids
+
other_token_ids
),
mm_dict
)
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
if
isinstance
(
data
,
tuple
):
if
not
isinstance
(
data
,
list
):
(
audio
,
sr
)
=
cast
(
Tuple
[
np
.
ndarray
,
Union
[
float
,
int
]],
data
)
data
=
[
data
]
audio_features
=
[]
for
audio_input
in
data
:
if
not
isinstance
(
audio_input
,
tuple
):
raise
NotImplementedError
(
f
"Unsupported data type:
{
type
(
audio_input
)
}
"
)
(
audio
,
sr
)
=
cast
(
Tuple
[
np
.
ndarray
,
Union
[
float
,
int
]],
audio_input
)
feature_extractor
=
whisper_feature_extractor
(
ctx
)
feature_extractor
=
whisper_feature_extractor
(
ctx
)
if
sr
!=
feature_extractor
.
sampling_rate
:
if
sr
!=
feature_extractor
.
sampling_rate
:
...
@@ -121,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
...
@@ -121,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
# Not enough audio; pad it.
# Not enough audio; pad it.
audio
=
np
.
pad
(
audio
,
(
0
,
minimum_audio_length
-
len
(
audio
)))
audio
=
np
.
pad
(
audio
,
(
0
,
minimum_audio_length
-
len
(
audio
)))
return
MultiModalInputs
({
single_audio_features
=
feature_extractor
(
"audio_features"
:
audio
,
sampling_rate
=
sr
,
padding
=
"longest"
,
feature_extractor
(
audio
,
return_tensors
=
"pt"
)[
"input_features"
]
sampling_rate
=
sr
,
padding
=
"longest"
,
return_tensors
=
"pt"
)[
"input_features"
]
})
raise
NotImplementedError
(
f
"Unsupported data type:
{
type
(
data
)
}
"
)
# Remove the batch dimension because we're wrapping it in a list.
audio_features
.
append
(
single_audio_features
.
squeeze
(
0
))
return
MultiModalInputs
({
"audio_features"
:
audio_features
})
def
input_processor_for_ultravox
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
def
input_processor_for_ultravox
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
...
@@ -138,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -138,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
return
llm_inputs
feature_extractor
=
whisper_feature_extractor
(
ctx
)
feature_extractor
=
whisper_feature_extractor
(
ctx
)
audio_data
,
sample_rate
=
multi_modal_data
[
"audio"
]
audios
=
multi_modal_data
[
"audio"
]
if
not
isinstance
(
audios
,
list
):
audio_length
=
audio_data
.
shape
[
0
]
audios
=
[
audios
]
if
sample_rate
!=
feature_extractor
.
sampling_rate
:
# Account for resampling.
audio_token_counts
=
[]
adjustment
=
feature_extractor
.
sampling_rate
/
sample_rate
for
audio_data
,
sample_rate
in
audios
:
audio_length
=
math
.
ceil
(
adjustment
*
audio_length
)
audio_length
=
audio_data
.
shape
[
0
]
if
sample_rate
!=
feature_extractor
.
sampling_rate
:
feature_extractor_output_length
=
math
.
ceil
(
# Account for resampling.
(
audio_length
-
adjustment
=
feature_extractor
.
sampling_rate
/
sample_rate
(
feature_extractor
.
hop_length
-
1
))
/
feature_extractor
.
hop_length
)
audio_length
=
math
.
ceil
(
adjustment
*
audio_length
)
uv_config
=
ctx
.
get_hf_config
(
UltravoxConfig
)
feature_extractor_output_length
=
math
.
ceil
(
audio_num_tokens
=
min
(
(
audio_length
-
(
feature_extractor
.
hop_length
-
1
))
/
max
(
feature_extractor
.
hop_length
)
1
,
math
.
ceil
(
feature_extractor_output_length
/
uv_config
=
ctx
.
get_hf_config
(
UltravoxConfig
)
(
uv_config
.
stack_factor
*
2
))),
audio_num_tokens
=
min
(
get_ultravox_max_audio_tokens
(
ctx
))
max
(
1
,
math
.
ceil
(
feature_extractor_output_length
/
(
uv_config
.
stack_factor
*
2
))),
get_ultravox_max_audio_tokens
(
ctx
))
audio_token_counts
.
append
(
audio_num_tokens
)
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
)
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
)
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
...
@@ -164,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -164,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs
.
get
(
"prompt"
),
llm_inputs
.
get
(
"prompt"
),
llm_inputs
[
"prompt_token_ids"
],
llm_inputs
[
"prompt_token_ids"
],
placeholder_token_id
=
_AUDIO_PLACEHOLDER_TOKEN
,
placeholder_token_id
=
_AUDIO_PLACEHOLDER_TOKEN
,
repeat_count
=
audio_
num_
tokens
,
repeat_count
=
audio_token
_count
s
,
)
)
# NOTE: Create a defensive copy of the original inputs
# NOTE: Create a defensive copy of the original inputs
...
@@ -338,45 +353,52 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -338,45 +353,52 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of audio features. "
raise
ValueError
(
"Incorrect type of audio features. "
f
"Got type:
{
type
(
audio_features
)
}
"
)
f
"Got type:
{
type
(
audio_features
)
}
"
)
# Remove the N dimension until multiple audios are supported.
if
isinstance
(
audio_features
,
torch
.
Tensor
):
audio_features
=
audio_features
.
squeeze
(
1
)
else
:
audio_features
=
[
t
.
squeeze
(
0
)
for
t
in
audio_features
]
return
UltravoxAudioFeatureInputs
(
type
=
"audio_features"
,
return
UltravoxAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
audio_features
)
data
=
audio_features
)
if
audio_embeds
is
not
None
:
if
audio_embeds
is
not
None
:
if
not
isinstance
(
audio_embeds
,
torch
.
Tensor
):
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)
):
raise
ValueError
(
"Incorrect type of audio embeds. "
raise
ValueError
(
"Incorrect type of audio embeds. "
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
# Remove the N dimension until multiple audios are supported.
audio_embeds
=
audio_embeds
.
squeeze
(
1
)
return
UltravoxAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
return
UltravoxAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
data
=
audio_embeds
)
data
=
audio_embeds
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_audio_input
(
def
_process_audio_input
(
self
,
audio_input
:
UltravoxAudioInputs
self
,
audio_input
:
UltravoxAudioInputs
)
->
NestedTensors
:
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
if
audio_input
[
"type"
]
==
"audio_embeds"
:
if
audio_input
[
"type"
]
==
"audio_embeds"
:
return
audio_input
[
"data"
]
return
audio_input
[
"data"
]
audio_features
=
audio_input
[
"data"
]
audio_features
=
audio_input
[
"data"
]
if
isinstance
(
audio_features
,
list
):
if
isinstance
(
audio_features
,
torch
.
Tensor
):
# TODO: Batch these through the encoder/projector instead of
# Combine the B and N dimensions for the encoder/projector
# serializing them.
flattened
=
flatten_bn
(
audio_features
)
return
[
flattened_embeddings
=
self
.
_audio_features_to_embeddings
(
self
.
_audio_features_to_embeddings
(
flattened
)
features
.
unsqueeze
(
0
)).
squeeze
(
0
)
for
features
in
audio_features
# Restore the original dimensions
]
embeddings
=
flattened_embeddings
.
unflatten
(
else
:
0
,
audio_features
.
shape
[:
2
])
return
self
.
_audio_features_to_embeddings
(
audio_features
)
return
embeddings
result
=
[]
# TODO: Batch heterogeneous tensors through the encoder/projector
for
audio_features_item
in
audio_features
:
if
isinstance
(
audio_features_item
,
torch
.
Tensor
):
result
.
append
(
self
.
_audio_features_to_embeddings
(
audio_features_item
))
else
:
embeddings
=
[
# Add a batch dimension to embed it, then remove it.
self
.
_audio_features_to_embeddings
(
tensor
.
unsqueeze
(
0
)
).
squeeze
(
0
)
for
tensor
in
audio_features_item
]
result
.
append
(
embeddings
)
return
result
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
...
@@ -393,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -393,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
with the `input_ids`.
with the `input_ids`.
Args:
Args:
input
_features: A batch of audio inputs
,
[
1
, 80, M].
audio
_features: A batch of audio inputs [
B, N
, 80, M].
"""
"""
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
not
None
:
if
audio_input
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment