Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
03fe18ae
Unverified
Commit
03fe18ae
authored
Mar 08, 2025
by
Isotr0py
Committed by
GitHub
Mar 08, 2025
Browse files
[VLM] Add TP support for Phi-4-MM (#14453)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
cb8bdfad
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
295 deletions
+50
-295
examples/offline_inference/audio_language.py
examples/offline_inference/audio_language.py
+1
-0
vllm/model_executor/models/phi4mm.py
vllm/model_executor/models/phi4mm.py
+24
-49
vllm/model_executor/models/phi4mm_audio.py
vllm/model_executor/models/phi4mm_audio.py
+18
-150
vllm/model_executor/models/phi4mm_utils.py
vllm/model_executor/models/phi4mm_utils.py
+7
-96
No files found.
examples/offline_inference/audio_language.py
View file @
03fe18ae
...
...
@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
enable_lora
=
True
,
max_lora_rank
=
320
,
lora_extra_vocab_size
=
0
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
},
)
lora_request
=
LoRARequest
(
"speech"
,
1
,
speech_lora_path
)
# To maintain code compatibility in this script, we add LoRA here.
...
...
vllm/model_executor/models/phi4mm.py
View file @
03fe18ae
...
...
@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
from
transformers.utils
import
logging
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
)
from
vllm.inputs.data
import
TokenInputs
,
token_inputs
...
...
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
from
.phi4mm_audio
import
AudioEmbedding
from
.utils
import
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
from
.vision_siglip_navit
import
get_siglip_vision_model
# <|endoftext10|> (see vocab.json in hf model)
...
...
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
# n_embed or hidden_size
hidden_size
=
config
.
n_embd
if
hasattr
(
config
,
'n_embd'
)
else
config
.
hidden_size
if
hasattr
(
config
,
'embd_pdrop'
)
or
hasattr
(
config
,
'embed_pdrop'
):
embd_drop
=
config
.
embd_pdrop
if
hasattr
(
config
,
'embd_pdrop'
)
else
config
.
embed_pdrop
self
.
drop
=
nn
.
Dropout
(
embd_drop
)
else
:
self
.
drop
=
None
# layer_idx to output the img features
if
isinstance
(
config
.
img_processor
,
dict
):
...
...
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
],
}
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
"base_layer."
:
""
,
},
orig_to_new_prefix
=
{
"model.embed_tokens_extend.audio_embed.audio_projection.vision."
:
"embed_tokens_extend.audio_projection_for_vision."
,
"model.embed_tokens_extend.audio_embed.audio_projection.speech."
:
"embed_tokens_extend.audio_projection."
,
"model.embed_tokens_extend.audio_embed."
:
"embed_tokens_extend."
,
"model.embed_tokens_extend.image_embed."
:
"vision_encoder."
,
},
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
self
.
lora_config
=
lora_config
# Tensor/Pipeline parallel not supported for now.
assert
get_tensor_model_parallel_world_size
(
)
==
1
,
"tensor parallel is not supported"
assert
get_pp_group
(
).
world_size
==
1
,
"pipeline parallel is not supported"
...
...
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
)
return
merged_embeds
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
None
:
weights
=
{
name
:
weight
for
name
,
weight
in
weights
}
adjusted_weights
=
{}
for
name
,
weight
in
weights
.
items
():
# NOTE vision-speech tasks use a separate projection layer
audio_proj_4v
=
\
"model.embed_tokens_extend.audio_embed.audio_projection.vision"
if
name
.
startswith
(
audio_proj_4v
):
name
=
name
.
replace
(
audio_proj_4v
,
"embed_tokens_extend.audio_projection_for_vision"
)
name
=
(
name
.
replace
(
"model.embed_tokens_extend.audio_embed."
\
"audio_projection.speech."
,
"embed_tokens_extend.audio_projection."
,
).
replace
(
"model.embed_tokens_extend.audio_embed."
,
"embed_tokens_extend."
,
).
replace
(
"model.embed_tokens_extend.image_embed."
,
"vision_encoder."
))
# NOTE: this is deal with LoRA injection, where `base_layer`
# remains as the original layer in the model
if
name
.
endswith
(
".base_layer.weight"
):
name
=
name
.
replace
(
".base_layer.weight"
,
".weight"
)
adjusted_weights
[
name
]
=
weight
missing_keys
,
unexpected_keys
=
self
.
load_state_dict
(
adjusted_weights
,
strict
=
False
)
logger
.
debug
(
"*** missing keys:"
)
for
key
in
missing_keys
:
logger
.
debug
(
key
)
logger
.
debug
(
"**** unexpected keys:"
)
for
key
in
unexpected_keys
:
logger
.
debug
(
key
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
None
:
weights
=
((
name
,
data
)
for
name
,
data
in
weights
if
"lora"
not
in
name
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
...
...
vllm/model_executor/models/phi4mm_audio.py
View file @
03fe18ae
...
...
@@ -6,69 +6,26 @@
#!/usr/bin/env python3
import
abc
import
math
from
functools
import
partial
from
typing
import
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
List
,
Literal
,
Optional
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
,
nn
from
torch.distributed.algorithms._checkpoint.checkpoint_wrapper
import
(
Checkpoint
Impl
,
CheckpointWrapper
,
checkpoint_wrapper
,
offload_w
rapper
)
Checkpoint
W
rapper
)
from
torch.distributed.fsdp.fully_sharded_data_parallel
import
(
FullyShardedDataParallel
)
from
torch.utils.checkpoint
import
checkpoint
from
transformers
import
PretrainedConfig
from
vllm.model_executor.models.phi4mm_utils
import
(
AbsolutePositionalEncoding
,
ConvModule
,
FeedForward
,
MeanVarianceNormLayer
,
MultiHeadedAttention
,
NemoConvSubsampling
,
T5RelativeAttentionLogitBias
,
adaptive_enc_mask
,
attn_checkpointing
,
embedding_checkpoint_wrapper
,
get_offset
,
repeat
,
unfold_tensor
,
validate_checkpointing_config
)
MultiHeadedAttention
,
MultiSequential
,
NemoConvSubsampling
,
T5RelativeAttentionLogitBias
,
adaptive_enc_mask
,
get_offset
,
unfold_tensor
)
_AUDIO_PLACEHOLDER_TOKEN_ID
=
200011
# <|endoftext11|>
def
encoder_checkpoint_wrapper
(
activation_checkpointing
:
Union
[
str
,
Dict
],
layer_cls
:
type
,
idx
:
int
=
0
,
)
->
Callable
:
"""return encoder activation checkpoint wrapper"""
validate_checkpointing_config
(
activation_checkpointing
)
if
isinstance
(
activation_checkpointing
,
str
):
if
activation_checkpointing
:
if
activation_checkpointing
==
"offload"
:
return
offload_wrapper
return
partial
(
checkpoint_wrapper
)
return
lambda
x
:
x
if
isinstance
(
activation_checkpointing
,
dict
):
target_layer_cls
=
activation_checkpointing
.
get
(
"module"
,
"transformer"
)
if
target_layer_cls
.
lower
()
==
"transformer"
:
target_layer_cls
=
(
"EncoderLayer"
,
"ConformerEncoderLayer"
,
)
elif
target_layer_cls
.
lower
()
==
"attention"
:
target_layer_cls
=
(
"MultiHeadedAttention"
,
"MultiHeadAttention"
)
checkpointing_interval
=
activation_checkpointing
.
get
(
"interval"
,
1
)
offloading
=
activation_checkpointing
.
get
(
"offload"
,
False
)
impl
=
(
CheckpointImpl
.
REENTRANT
if
activation_checkpointing
.
get
(
"reentrant"
,
True
)
else
CheckpointImpl
.
NO_REENTRANT
)
if
(
idx
%
checkpointing_interval
==
0
and
layer_cls
.
__name__
in
target_layer_cls
):
if
offloading
:
return
offload_wrapper
return
partial
(
checkpoint_wrapper
,
checkpoint_impl
=
impl
)
return
lambda
x
:
x
raise
ValueError
(
"Invalid activation_checkpointing config"
)
class
ConformerEncoderLayer
(
nn
.
Module
):
"""ConformerEncoder Layer module.
for more details see conformer paper:
...
...
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
bias_in_glu
=
bias_in_glu
,
)
self
.
self_attn
=
encoder_checkpoint_wrapper
(
activation_checkpointing
,
MultiHeadedAttention
,
)(
MultiHeadedAttention
(
self
.
self_attn
=
MultiHeadedAttention
(
n_head
,
d_model
,
dropout_rate
,
...
...
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
use_pt_scaled_dot_product_attention
=
use_pt_scaled_dot_product_attention
,
group_size
=
attn_group_sizes
,
)
)
)
self
.
conv
=
ConvModule
(
d_model
,
ext_pw_out_channel
,
...
...
@@ -441,24 +395,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
else
:
raise
NotImplementedError
def
post_init
(
self
,
init_model_config
):
pretrained_speech_encoder_path
=
init_model_config
.
get
(
"pretrained_speech_encoder_path"
,
None
)
if
pretrained_speech_encoder_path
:
model_state
=
torch
.
load
(
pretrained_speech_encoder_path
,
map_location
=
"cpu"
)
encoder_state_dict
=
{}
for
k
,
v
in
model_state
.
items
():
if
"encoder."
in
k
:
tmp_k
=
k
.
replace
(
"encoder."
,
""
)
encoder_state_dict
[
tmp_k
]
=
v
if
hasattr
(
self
,
"encoder_embedding"
):
del
self
.
encoder_embedding
self
.
load_state_dict
(
encoder_state_dict
)
if
not
hasattr
(
self
,
"encoder_embedding"
):
self
.
encoder_embedding
=
MeanVarianceNormLayer
(
self
.
encoder_embedding_config
[
"input_size"
])
...
...
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
# Create mask matrix for streaming
# S stores start index. if chunksize is 18, s is [0,18,36,....]
chunk_start_idx
=
np
.
arange
(
0
,
seq_len
,
chunk_size_train_eff
)
# avoid randomness when run evaluation or decoding
if
self
.
training
and
np
.
random
.
rand
()
>
0.5
:
# Either first or last chunk is not complete.
# If only the last one is not complete, EOS is not effective
chunk_start_idx
=
seq_len
-
chunk_start_idx
chunk_start_idx
=
chunk_start_idx
[::
-
1
]
chunk_start_idx
=
chunk_start_idx
[:
-
1
]
chunk_start_idx
=
np
.
insert
(
chunk_start_idx
,
0
,
0
)
enc_streaming_mask
=
(
adaptive_enc_mask
(
seq_len
,
chunk_start_idx
,
...
...
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
self
.
num_blocks
=
num_blocks
self
.
num_lang
=
num_lang
self
.
kernel_size
=
kernel_size
self
.
embed
=
embedding_checkpoint_wrapper
(
activation_checkpointing
)(
self
.
embed
)
self
.
replication_pad_for_subsample_embedding
:
bool
=
(
replication_pad_for_subsample_embedding
)
assert
(
self
.
num_heads
%
attention_group_size
==
0
),
"attention_group_size must divide n_head"
self
.
num_heads_k
=
self
.
num_heads
//
attention_group_size
self
.
encoders
=
repeat
(
num_blocks
,
lambda
i
:
encoder_checkpoint_wrapper
(
activation_checkpointing
,
ConformerEncoderLayer
,
i
)
(
ConformerEncoderLayer
(
self
.
encoders
=
MultiSequential
(
*
[
ConformerEncoderLayer
(
d_model
=
attention_dim
,
ext_pw_out_channel
=
ext_pw_out_channel
,
depthwise_seperable_out_channel
=
depthwise_seperable_out_channel
,
depthwise_seperable_out_channel
=
depthwise_seperable_out_channel
,
depthwise_multiplier
=
depthwise_multiplier
,
n_head
=
attention_heads
,
d_ffn
=
linear_units
,
...
...
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
bias_in_glu
=
bias_in_glu
,
linear_glu_in_convm
=
linear_glu_in_convm
,
attention_glu_type
=
attention_glu_type
,
activation_checkpointing
=
attn_checkpointing
(
activation_checkpointing
,
i
),
activation_checkpointing
=
activation_checkpointing
,
export
=
export
,
use_pt_scaled_dot_product_attention
=
use_pt_scaled_dot_product_attention
,
attn_group_sizes
=
attention_group_size
,
))
,
)
)
for
_
in
range
(
num_blocks
)
]
)
self
.
extra_layer_output_idx
=
extra_layer_output_idx
self
.
extra_multi_layer_output_idxs
=
extra_multi_layer_output_idxs
# Make a zeros scalar we can use in get_initial_state to determine
...
...
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
return
input_tensor
,
masks
# , layer_emb
def
gradient_checkpointing_enable
(
self
):
pass
class
WindowQformer
(
nn
.
Module
):
"""Window-level Qformer"""
...
...
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
self
.
after_norm
=
(
nn
.
LayerNorm
(
attention_dim
,
eps
=
1e-12
)
if
normalize_before
else
None
)
self
.
window_size
=
window_size
self
.
gradient_checkpointing_enable
=
False
def
enable_gradient_checkpointing
(
self
):
self
.
gradient_checkpointing_enable
=
True
def
disable_gradient_checkpointing
(
self
):
self
.
gradient_checkpointing_enable
=
False
def
forward
(
self
,
audio_embed
,
mask
,
embed_len
=
None
):
"""forward decoder"""
...
...
@@ -1111,16 +1022,6 @@ class WindowQformer(nn.Module):
# NT' x 1 x D
q
=
self
.
queries
.
expand
(
bsz
*
slen
,
-
1
,
-
1
)
for
layer
in
self
.
decoders
:
if
self
.
gradient_checkpointing_enable
and
self
.
training
:
q
=
checkpoint
(
layer
.
__call__
,
q
,
embed_chunk
,
None
,
mask
,
use_reentrant
=
True
,
)
else
:
q
=
layer
(
tgt
=
q
,
memory
=
embed_chunk
,
tgt_mask
=
None
,
...
...
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
hidden_size
=
(
config
.
n_embd
if
hasattr
(
config
,
"n_embd"
)
else
config
.
hidden_size
)
if
hasattr
(
config
,
"embd_pdrop"
)
or
hasattr
(
config
,
"embed_pdrop"
):
embd_drop
=
(
config
.
embd_pdrop
if
hasattr
(
config
,
"embd_pdrop"
)
else
config
.
embed_pdrop
)
self
.
drop
=
nn
.
Dropout
(
embd_drop
)
else
:
self
.
drop
=
None
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
audio_dim_out
=
(
...
...
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
assert
encoder_config
is
not
None
self
.
encoder
=
ConformerEncoder
(
**
encoder_config
)
# fake initialization, create encoder_embedding layer only so that
# in decoding, all parameters can be loaded in
# from_pretrained_function in training, we do post init after
# from_pretrained function to make sure the correct initialization
self
.
encoder
.
post_init
({})
audio_dim_out
=
encoder_config
[
"attention_dim"
]
n_mels
=
encoder_config
[
"input_size"
]
else
:
...
...
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
else
:
self
.
conv_ds
=
None
enable_gradient_checkpointing
=
kwargs
.
get
(
"enable_gradient_checkpointing"
,
False
)
if
enable_gradient_checkpointing
:
self
.
encoder
.
gradient_checkpointing_enable
()
if
self
.
qformer
:
self
.
qformer
.
enable_gradient_checkpointing
()
projection_cls
=
kwargs
.
get
(
"projection_cls"
,
"linear"
)
if
projection_cls
==
"linear"
:
self
.
audio_projection
=
nn
.
Linear
(
audio_dim_out
,
hidden_size
)
...
...
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
hidden_states
.
dtype
).
to
(
hidden_states
.
device
))
idx
+=
cnt
else
:
if
self
.
training
:
# hidden_states[:, 0:img_set_tensor.shape[0]] =
# hidden_states[:, 0:img_set_tensor.shape[0]] +
# 0 * img_set_tensor.to(hidden_states.dtype)
# .to(hidden_states.device)
hidden_states
[:,
0
:
1
]
=
hidden_states
[:,
0
:
1
]
+
\
0
*
audio_set_tensor
[:,
0
:
1
].
to
(
hidden_states
.
dtype
)
\
.
to
(
hidden_states
.
device
)
if
self
.
drop
is
not
None
:
hidden_states
=
self
.
drop
(
hidden_states
)
return
hidden_states
vllm/model_executor/models/phi4mm_utils.py
View file @
03fe18ae
...
...
@@ -5,14 +5,11 @@
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import
math
from
functools
import
partial
from
typing
import
Callable
,
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
,
nn
from
torch.distributed.algorithms._checkpoint.checkpoint_wrapper
import
(
CheckpointImpl
,
checkpoint_wrapper
,
offload_wrapper
)
class
Block
(
nn
.
Module
):
...
...
@@ -873,10 +870,8 @@ class MeanVarianceNormLayer(nn.Module):
def
__init__
(
self
,
input_size
):
super
().
__init__
()
self
.
input_size
=
input_size
self
.
register_buffer
(
"global_mean"
,
torch
.
zeros
(
input_size
))
self
.
register_buffer
(
"global_invstd"
,
torch
.
ones
(
input_size
))
self
.
global_mean
:
Optional
[
Tensor
]
self
.
global_invstd
:
Optional
[
Tensor
]
self
.
global_mean
=
nn
.
Parameter
(
torch
.
zeros
(
input_size
))
self
.
global_invstd
=
nn
.
Parameter
(
torch
.
ones
(
input_size
))
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
"""MeanVarianceNormLayer Forward
...
...
@@ -1023,17 +1018,6 @@ class CausalConv2D(nn.Conv2d):
self
,
x
,
):
if
self
.
training
:
x
=
F
.
pad
(
x
,
pad
=
(
self
.
_left_padding
,
self
.
_right_padding
,
self
.
_left_padding
,
self
.
_right_padding
,
),
)
else
:
x
=
F
.
pad
(
x
,
pad
=
(
self
.
_left_padding
,
self
.
_right_padding
,
0
,
0
),
...
...
@@ -1840,68 +1824,6 @@ class MultiHeadedAttention(nn.Module):
return
self
.
linear_out
(
x
)
# (batch, time1, d_model)
def
validate_checkpointing_config
(
activation_checkpointing
):
"""validate activation checkpointing configuration"""
if
isinstance
(
activation_checkpointing
,
str
):
assert
activation_checkpointing
in
(
""
,
"checkpoint"
,
"offload"
,
),
"activation_checkpointing has to be a dict or a str in "
\
"('', 'checkpoint', 'offload')."
elif
isinstance
(
activation_checkpointing
,
dict
):
assert
activation_checkpointing
.
get
(
"module"
,
"transformer"
)
in
(
"transformer"
,
"attention"
,
),
"module in activation_checkpointing has to be in "
\
"('transformer', 'attention')."
else
:
raise
ValueError
(
"activation_checkpointing has to be a str"
\
" or dict."
)
def
embedding_checkpoint_wrapper
(
activation_checkpointing
:
Union
[
str
,
Dict
],
)
->
Callable
:
"""return encoder embedding activation checkpoint wrapper"""
validate_checkpointing_config
(
activation_checkpointing
)
if
isinstance
(
activation_checkpointing
,
str
):
if
activation_checkpointing
:
if
activation_checkpointing
==
"offload"
:
return
offload_wrapper
return
partial
(
checkpoint_wrapper
)
return
lambda
x
:
x
if
isinstance
(
activation_checkpointing
,
dict
):
enabled
=
activation_checkpointing
.
get
(
"embed"
,
False
)
if
enabled
:
offloading
=
activation_checkpointing
.
get
(
"offload"
,
False
)
if
offloading
:
return
offload_wrapper
impl
=
(
CheckpointImpl
.
REENTRANT
if
activation_checkpointing
.
get
(
"reentrant"
,
False
)
else
CheckpointImpl
.
NO_REENTRANT
)
return
partial
(
checkpoint_wrapper
,
checkpoint_impl
=
impl
)
return
lambda
x
:
x
raise
ValueError
(
"Invalid activation_checkpointing config"
)
def
attn_checkpointing
(
activation_checkpointing
:
Union
[
str
,
Dict
],
i
)
->
Union
[
str
,
Dict
]:
"""return activation checkpointing config for attention layer"""
if
isinstance
(
activation_checkpointing
,
str
):
return
""
if
isinstance
(
activation_checkpointing
,
dict
):
target_layer_cls
=
activation_checkpointing
.
get
(
"module"
,
"transformer"
)
checkpointing_interval
=
activation_checkpointing
.
get
(
"interval"
,
1
)
if
target_layer_cls
==
"attention"
and
i
%
checkpointing_interval
==
0
:
return
activation_checkpointing
return
""
raise
ValueError
(
"Invalid activation_checkpointing config"
)
class
MultiSequential
(
torch
.
nn
.
Sequential
):
"""Multi-input multi-output torch.nn.Sequential"""
...
...
@@ -1913,17 +1835,6 @@ class MultiSequential(torch.nn.Sequential):
return
args
def
repeat
(
repeat_num
,
module_gen_fn
):
"""repeat module N times
:param int repeat_num: repeat time
:param function module_gen_fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
"""
return
MultiSequential
(
*
[
module_gen_fn
(
i
)
for
i
in
range
(
repeat_num
)])
def
get_offset
(
input_layer
:
str
,
time_reduction
:
int
):
"""Get an offset. We will use the offset for determining #frames of a
subsampled feature.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment