Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
80e78d02
Unverified
Commit
80e78d02
authored
Mar 11, 2025
by
Farzad Abdolhosseini
Committed by
GitHub
Mar 12, 2025
Browse files
[Model] Extend Ultravox to accept audio longer than 30s (#13631)
Signed-off-by:
Farzad Abdolhosseini
<
farzad@fixie.ai
>
parent
4a42b9f5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
233 additions
and
81 deletions
+233
-81
tests/models/decoder_only/audio_language/test_ultravox.py
tests/models/decoder_only/audio_language/test_ultravox.py
+1
-1
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+50
-7
tests/models/registry.py
tests/models/registry.py
+1
-2
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+181
-71
No files found.
tests/models/decoder_only/audio_language/test_ultravox.py
View file @
80e78d02
...
...
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
from
....utils
import
RemoteOpenAIServer
from
...utils
import
check_logprobs_close
MODEL_NAME
=
"fixie-ai/ultravox-v0_
4
"
MODEL_NAME
=
"fixie-ai/ultravox-v0_
5-llama-3_2-1b
"
AudioTuple
=
tuple
[
np
.
ndarray
,
int
]
...
...
tests/models/multimodal/processing/test_common.py
View file @
80e78d02
# SPDX-License-Identifier: Apache-2.0
import
copy
from
functools
import
partial
from
typing
import
Optional
import
numpy
as
np
import
pytest
...
...
@@ -21,6 +23,7 @@ def _test_processing_correctness(
hit_rate
:
float
,
num_batches
:
int
,
simplify_rate
:
float
,
ignore_mm_keys
:
Optional
[
list
[
str
]]
=
None
,
):
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
...
...
@@ -123,7 +126,9 @@ def _test_processing_correctness(
hf_processor_mm_kwargs
=
{},
)
assert
baseline_result
==
cached_result
,
(
assert
_drop_mm_kwargs_keys
(
baseline_result
,
ignore_mm_keys
)
==
_drop_mm_kwargs_keys
(
cached_result
,
ignore_mm_keys
),
(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
baseline_tokenized_result
=
baseline_processor
.
apply
(
...
...
@@ -132,7 +137,9 @@ def _test_processing_correctness(
hf_processor_mm_kwargs
=
{},
)
assert
baseline_result
==
baseline_tokenized_result
,
(
assert
_drop_mm_kwargs_keys
(
baseline_result
,
ignore_mm_keys
)
==
_drop_mm_kwargs_keys
(
baseline_tokenized_result
,
ignore_mm_keys
),
(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
cached_tokenized_result
=
cached_processor
.
apply
(
...
...
@@ -141,7 +148,9 @@ def _test_processing_correctness(
hf_processor_mm_kwargs
=
{},
)
assert
cached_result
==
cached_tokenized_result
,
(
assert
_drop_mm_kwargs_keys
(
cached_result
,
ignore_mm_keys
)
==
_drop_mm_kwargs_keys
(
cached_tokenized_result
,
ignore_mm_keys
),
(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
...
...
@@ -173,7 +182,7 @@ def _test_processing_correctness(
"Qwen/Qwen2-VL-2B-Instruct"
,
"Qwen/Qwen2.5-VL-3B-Instruct"
,
"Qwen/Qwen2-Audio-7B-Instruct"
,
"fixie-ai/ultravox-v0_
4
"
,
"fixie-ai/ultravox-v0_
5-llama-3_2-1b
"
,
"openai/whisper-large-v3"
,
"google/paligemma-3b-mix-224"
,
"google/paligemma2-3b-ft-docci-448"
,
...
...
@@ -188,11 +197,19 @@ def test_processing_correctness(
num_batches
:
int
,
simplify_rate
:
float
,
):
ignore_mm_keys
=
None
if
'ultravox'
in
model_id
:
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
ignore_mm_keys
=
[
'audio_features'
]
_test_processing_correctness
(
model_id
,
hit_rate
=
hit_rate
,
num_batches
=
num_batches
,
simplify_rate
=
simplify_rate
,
ignore_mm_keys
=
ignore_mm_keys
,
)
...
...
@@ -221,3 +238,29 @@ def test_processing_correctness_phi3v(
num_batches
=
num_batches
,
simplify_rate
=
simplify_rate
,
)
def _drop_mm_kwargs_keys(
    result: dict,
    ignore_mm_keys: Optional[list[str]] = None,
) -> dict:
    """Drop specified keys from result['mm_kwargs'].

    This is mainly to avoid doing exact match of audio_features in ultravox.

    Args:
        result: Result to drop keys from
        ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']

    Returns:
        ``result`` itself (same object) when there is nothing to drop,
        i.e. ``ignore_mm_keys`` is empty/None or ``result`` has no
        'mm_kwargs' entry. Otherwise a deep copy of ``result`` with the
        keys removed, so the caller's object is never mutated.
    """
    if not ignore_mm_keys or 'mm_kwargs' not in result:
        return result

    # Work on a copy: the same result object is compared more than once.
    result = copy.deepcopy(result)
    mm_kwargs = result['mm_kwargs']
    for key in ignore_mm_keys:
        mm_kwargs.pop(key, None)

    # The per-item storage is a private attribute of the mm_kwargs container;
    # a plain mapping does not have it, so look it up defensively instead of
    # raising AttributeError.
    items_by_modality = getattr(mm_kwargs, '_items_by_modality', None) or {}
    for items in items_by_modality.values():
        for item in items:
            for key in ignore_mm_keys:
                item.pop(key, None)

    return result
tests/models/registry.py
View file @
80e78d02
...
...
@@ -284,8 +284,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2-VL-2B-Instruct"
),
# noqa: E501
"Qwen2_5_VLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-VL-3B-Instruct"
,
# noqa: E501
min_transformers_version
=
"4.49"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_4"
,
extras
=
{
"v0.5"
:
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
},
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
# noqa: E501
trust_remote_code
=
True
),
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
...
...
vllm/model_executor/models/ultravox.py
View file @
80e78d02
...
...
@@ -5,7 +5,7 @@
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
typing
import
Any
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch.utils.checkpoint
...
...
@@ -44,12 +44,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
_AUDIO_PLACEHOLDER_OVERRIDE
=
"<|reserved_special_token_0|>"
_AUDIO_PLACEHOLDER_TOKEN
=
128002
_AUDIO_TOKENS_PER_SECOND
=
6.25
_MAX_ENCODER_BATCH_SIZE
=
16
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
data
:
NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)`"""
"""Shape: `(batch_size, num_chunks, 80, M)`"""
lens
:
NestedTensors
"""
Length of the audio frames. Used for attention mask in WhisperEncoder.
Shape: `(batch_size, num_chunks)`
"""
token_len
:
NestedTensors
"""
Length of the audio tokens. Used for flattening the audio features.
Shape: `(batch_size, num_chunks)`
"""
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
...
...
@@ -78,6 +89,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
# token, thus we override placeholder with a reserved special
# token.
hf_processor
.
audio_token_replacement
=
_AUDIO_PLACEHOLDER_OVERRIDE
hf_processor
.
audio_replacement_token_id
=
_AUDIO_PLACEHOLDER_TOKEN
return
hf_processor
def
get_feature_extractor
(
...
...
@@ -104,7 +116,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
max_audio_tokens
=
math
.
ceil
(
feature_extractor
.
chunk_length
*
_AUDIO_TOKENS_PER_SECOND
)
return
{
"audio"
:
max_audio_tokens
}
return
{
"audio"
:
max_audio_tokens
*
_MAX_ENCODER_BATCH_SIZE
}
class
UltravoxDummyInputsBuilder
(
BaseDummyInputsBuilder
[
UltravoxProcessingInfo
]
...
...
@@ -118,7 +130,8 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
feature_extractor
=
self
.
info
.
get_feature_extractor
()
sampling_rate
=
feature_extractor
.
sampling_rate
audio_len
=
feature_extractor
.
chunk_length
*
sampling_rate
audio_len
=
(
feature_extractor
.
chunk_length
*
sampling_rate
*
_MAX_ENCODER_BATCH_SIZE
)
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
mm_data
=
{
...
...
@@ -160,41 +173,38 @@ class UltravoxMultiModalProcessor(
mm_kwargs
=
dict
(
**
mm_kwargs
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
include_audio_num_chunks
=
True
,
)
# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
audio_features
,
audio_token_len
=
[],
[]
shared_outputs
=
{}
for
audio
in
audios
:
# NOTE: Ultravox processor accepts "audio" instead of "audios"
item_processor_data
=
dict
(
**
mm_data
,
audio
=
audio
)
item_processor_data
=
dict
(
**
mm_data
,
audios
=
audios
)
item_
output
s
=
super
().
_call_hf_processor
(
output
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
item_processor_data
,
mm_kwargs
=
mm_kwargs
,
)
output
[
'audio_features'
]
=
output
.
pop
(
'audio_values'
)
audio_features
.
append
(
item_outputs
.
pop
(
"audio_values"
)[
0
])
audio_token_len
.
append
(
item_outputs
.
pop
(
"audio_token_len"
).
item
())
shared_outputs
=
item_outputs
combined_outputs
=
dict
(
**
shared_outputs
,
audio_features
=
audio_features
,
audio_token_len
=
audio_token_len
,
)
return
BatchFeature
(
combined_outputs
)
return
output
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
num_chunks
=
hf_inputs
.
get
(
'audio_num_chunks'
,
torch
.
zeros
(
0
))
return
dict
(
audio_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_token_len
=
MultiModalFieldConfig
.
batched
(
"audio"
),
# To handle audio longer than 30s, each audio may be split into
# multiple chunks; as such, their batch dimension can be
# higher than the number of audio samples.
audio_features
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
num_chunks
),
audio_token_len
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
num_chunks
),
audio_lens
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
num_chunks
),
# num_chunks can convert audio_chunked to audio batch dimension
audio_num_chunks
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
...
...
@@ -205,14 +215,23 @@ class UltravoxMultiModalProcessor(
out_mm_kwargs
:
MultiModalKwargs
,
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
replacement_id
=
vocab
[
hf_processor
.
audio_token_replacement
]
# type: ignore
replacement_id
=
hf_processor
.
audio_replacement_token_id
# type: ignore
# Each audio can be split into multiple chunks.
# chunks_start_idx[i] indicates the start index of the chunks
# belonging to the i-th audio.
num_chunks
=
out_mm_kwargs
.
get
(
"audio_num_chunks"
,
torch
.
zeros
(
0
))
chunks_start_idx
:
torch
.
Tensor
=
torch
.
cumsum
(
num_chunks
,
dim
=
0
,
dtype
=
torch
.
int32
)
chunks_start_idx
=
torch
.
cat
(
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
chunks_start_idx
])
def
get_replacement_ultravox
(
item_idx
:
int
):
audio_token_len
=
out_mm_kwargs
[
"audio_token_len"
][
item_idx
]
start
=
chunks_start_idx
[
item_idx
]
end
=
chunks_start_idx
[
item_idx
+
1
]
audio_token_len
=
out_mm_kwargs
[
"audio_token_len"
][
start
:
end
].
sum
()
return
[
replacement_id
]
*
int
(
audio_token_len
)
# type: ignore
return
[
...
...
@@ -304,12 +323,49 @@ class ModifiedWhisperEncoder(WhisperEncoder):
base_model_prefix
=
"model.encoder"
def __init__(self, *args, **kwargs):
    """Construct the parent encoder, then flag its config as non-decoder.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)
    # NOTE(review): presumably needed so that mask helpers relying on
    # `config.is_decoder` do not build a causal mask -- confirm.
    encoder_config = self.config
    encoder_config.is_decoder = False
@property
def max_context_length(self):
    """Product of the configured source positions and both conv strides.

    Computed as ``config.max_source_positions`` multiplied by the first
    stride of ``conv1`` and of ``conv2``.
    """
    conv_stride_product = self.conv1.stride[0] * self.conv2.stride[0]
    return self.config.max_source_positions * conv_stride_product
def get_attention_mask_by_audio_len(self,
                                    audio_lens: Optional[torch.Tensor],
                                    hidden_states: torch.Tensor):
    """
    Build an attention mask that hides padded audio positions.

    Steps, per sample in the batch:
    - Map the raw audio length to the feature length after the
      convolutions (`_get_feat_extract_output_lengths`).
    - Build a bool mask over dim 1 of `hidden_states`:
      True for valid positions, False for padding.
    - Hand that mask to `get_extended_attention_mask` to obtain the
      format expected by the transformer layers.
      NOTE(review): in Hugging Face's implementation this is an additive
      mask (0 where attention is allowed, a large negative value where it
      is not) -- confirm against the installed transformers version.

    Masking padding the same way at training and inference time keeps the
    model's behavior consistent between the two.

    Returns None when `audio_lens` is None (no masking requested).
    """
    if audio_lens is None:
        return None

    feature_lens = self._get_feat_extract_output_lengths(audio_lens)
    seq_len = hidden_states.shape[1]
    positions = torch.arange(seq_len, device=hidden_states.device)
    # (1, seq_len) < (num_samples, 1) -> (num_samples, seq_len)
    is_valid = positions.unsqueeze(0) < feature_lens.view(-1, 1)
    return self.get_extended_attention_mask(
        is_valid,
        None,
        dtype=hidden_states.dtype,
    )
def
forward
(
self
,
input_features
,
input_features
:
torch
.
Tensor
,
audio_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
):
expected_seq_length
=
(
self
.
config
.
max_source_positions
*
self
.
conv1
.
stride
[
0
]
*
self
.
conv2
.
stride
[
0
])
expected_seq_length
=
self
.
max_context_length
if
input_features
.
shape
[
-
1
]
>
expected_seq_length
:
raise
ValueError
(
f
"Whisper expects the mel input features to be of length "
...
...
@@ -328,10 +384,13 @@ class ModifiedWhisperEncoder(WhisperEncoder):
p
=
self
.
dropout
,
training
=
self
.
training
)
attention_mask
=
self
.
get_attention_mask_by_audio_len
(
audio_lens
,
hidden_states
)
for
encoder_layer
in
self
.
layers
:
layer_outputs
=
encoder_layer
(
hidden_states
,
None
,
attention_mask
,
layer_head_mask
=
None
,
)
...
...
@@ -409,17 +468,34 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
)
def
_audio_features_to_embeddings
(
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
audio_input
=
input_features
.
to
(
self
.
audio_tower
.
dtype
)
audio_features
=
self
.
audio_tower
(
audio_input
)
audio_features
=
audio_features
.
to
(
self
.
audio_tower
.
dtype
)
audio_embeddings
=
self
.
multi_modal_projector
(
audio_features
)
self
,
input_features
:
torch
.
Tensor
,
audio_lens
:
torch
.
Tensor
)
->
torch
.
Tensor
:
audio_features
=
input_features
.
to
(
self
.
audio_tower
.
dtype
)
batch_size
=
audio_features
.
size
(
0
)
audio_embeddings
=
[]
# Process audio features in batches to keep memory usage predictable
for
start
in
range
(
0
,
batch_size
,
_MAX_ENCODER_BATCH_SIZE
):
end
=
min
(
start
+
_MAX_ENCODER_BATCH_SIZE
,
batch_size
)
# Process through audio tower
batch_features
=
self
.
audio_tower
(
audio_features
[
start
:
end
],
audio_lens
[
start
:
end
])
batch_features
=
batch_features
.
to
(
self
.
audio_tower
.
dtype
)
# Process through projector
batch_embeddings
=
self
.
multi_modal_projector
(
batch_features
)
audio_embeddings
.
append
(
batch_embeddings
)
# Concatenate results
audio_embeddings
=
torch
.
cat
(
audio_embeddings
,
dim
=
0
)
return
audio_embeddings
def
_parse_and_validate_audio_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
UltravoxAudioInputs
]:
audio_features
=
kwargs
.
pop
(
"audio_features"
,
None
)
audio_embeds
=
kwargs
.
pop
(
"audio_embeds"
,
None
)
audio_lens
=
kwargs
.
pop
(
"audio_lens"
,
None
)
audio_token_len
=
kwargs
.
pop
(
"audio_token_len"
,
None
)
if
audio_features
is
None
and
audio_embeds
is
None
:
return
None
...
...
@@ -430,7 +506,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
f
"Got type:
{
type
(
audio_features
)
}
"
)
return
UltravoxAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
audio_features
)
data
=
audio_features
,
lens
=
audio_lens
,
token_len
=
audio_token_len
)
if
audio_embeds
is
not
None
:
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)):
...
...
@@ -447,34 +525,34 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
if
audio_input
[
"type"
]
==
"audio_embeds"
:
return
audio_input
[
"data"
]
audio_features
=
audio_input
[
"data"
]
if
isinstance
(
audio_features
,
torch
.
Tensor
):
# Combine the B and N dimensions for the encoder/projector
flattened
=
flatten_bn
(
audio_features
)
flattened_embeddings
=
self
.
_audio_features_to_embeddings
(
flattened
)
# Restore the original dimensions
embeddings
=
flattened_embeddings
.
unflatten
(
0
,
audio_features
.
shape
[:
2
])
return
embeddings
result
=
[]
# TODO: Batch heterogeneous tensors through the encoder/projector
for
audio_features_item
in
audio_features
:
if
isinstance
(
audio_features_item
,
torch
.
Tensor
):
result
.
append
(
self
.
_audio_features_to_embeddings
(
audio_features_item
))
# Pad and concatenate audio features
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
audio_features
=
pad_and_concat_to_dim3
(
audio_input
[
"data"
])
if
isinstance
(
audio_input
[
'lens'
],
list
):
# [B1, B2] -> [B1+B2]
audio_lens
=
torch
.
cat
(
audio_input
[
'lens'
])
audio_token_len
=
torch
.
cat
(
audio_input
[
'token_len'
])
else
:
embeddings
=
[
# Add a batch dimension to embed it, then remove it.
self
.
_audio_features_to_embeddings
(
tensor
.
unsqueeze
(
0
)
).
squeeze
(
0
)
for
tensor
in
audio_features_item
]
result
.
append
(
embeddings
)
audio_lens
=
flatten_bn
(
audio_input
[
'lens'
])
audio_token_len
=
flatten_bn
(
audio_input
[
'token_len'
])
embeddings
=
self
.
_audio_features_to_embeddings
(
audio_features
,
audio_lens
)
# We should flatten and concatenate embeddings based on token lengths
# For example, with token_len = [4, 2, 3], flattened_embeddings will be
# concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])
return
result
# Create a mask of valid indices based on token lengths
max_len
=
embeddings
.
shape
[
1
]
indices
=
torch
.
arange
(
max_len
,
device
=
embeddings
.
device
).
expand
(
embeddings
.
shape
[
0
],
-
1
)
mask
=
indices
<
audio_token_len
[:,
None
]
# Apply mask and flatten
flattened_embeddings
=
embeddings
[
mask
]
return
flattened_embeddings
def
get_multimodal_embeddings
(
self
,
**
kwargs
...
...
@@ -521,7 +599,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
with the `input_ids`.
Args:
audio_features: A batch of audio inputs [B, N, 80, M].
audio_features: A batch of audio input chunks [B, N, 80, M].
audio_lens: Length of audio frames for each audio chunk [B].
audio_token_len: Length of audio tokens for each audio chunk [B'].
Note: batch dim is different from batch dim in audio chunks.
"""
if
intermediate_tensors
is
not
None
:
...
...
@@ -560,3 +642,31 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
loader
=
AutoWeightsLoader
(
self
,
ignore_unexpected_prefixes
=
[
"audio_tower."
])
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
def pad_and_concat_to_dim3(
    features: Union[torch.Tensor, List[torch.Tensor],
                    List[List[torch.Tensor]]]
) -> torch.Tensor:
    """
    Merge a tensor, or a (possibly nested) list of tensors, into one tensor.

    A bare tensor is returned unchanged unless it has more than three
    dims, in which case it goes through `flatten_bn` first. Lists are
    handled recursively: every element is reshaped to three dims,
    right-padded with zeros along the last dim to the longest element,
    and concatenated along dim 0.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        # Flatten [B, N, 80, M] -> [B * N, 80, M]
        return flatten_bn(features) if features.ndim > 3 else features

    merged = [pad_and_concat_to_dim3(entry) for entry in features]
    widest = max(t.shape[-1] for t in merged)

    # Pad and concatenate:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    padded = []
    for t in merged:
        # Ensure the tensor has dim=3 before padding its last dim
        t3d = t.view(-1, *t.shape[-2:])
        padded.append(F.pad(t3d, (0, widest - t3d.shape[-1])))
    return torch.cat(padded)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment