Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
68be1d3c
Unverified
Commit
68be1d3c
authored
Apr 18, 2024
by
Yoach Lacombe
Committed by
GitHub
Apr 18, 2024
Browse files
fix Parameter dtype in audio models (#30310)
parent
79132145
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
33 additions
and
33 deletions
+33
-33
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
+3
-3
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/hubert/modeling_hubert.py
+3
-3
src/transformers/models/sew/modeling_sew.py
src/transformers/models/sew/modeling_sew.py
+3
-3
src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/sew_d/modeling_sew_d.py
+3
-3
src/transformers/models/speecht5/modeling_speecht5.py
src/transformers/models/speecht5/modeling_speecht5.py
+3
-3
src/transformers/models/unispeech/modeling_unispeech.py
src/transformers/models/unispeech/modeling_unispeech.py
+3
-3
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
...ansformers/models/unispeech_sat/modeling_unispeech_sat.py
+3
-3
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+3
-3
src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
...ansformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
+3
-3
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
.../models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+3
-3
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/wavlm/modeling_wavlm.py
+3
-3
No files found.
src/transformers/models/data2vec/modeling_data2vec_audio.py
View file @
68be1d3c
...
...
@@ -822,7 +822,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
Data2VecAudioEncoder
(
config
)
...
...
@@ -858,7 +858,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -868,7 +868,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/hubert/modeling_hubert.py
View file @
68be1d3c
...
...
@@ -974,7 +974,7 @@ class HubertModel(HubertPreTrainedModel):
self
.
feature_projection
=
HubertFeatureProjection
(
config
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
HubertEncoderStableLayerNorm
(
config
)
...
...
@@ -1005,7 +1005,7 @@ class HubertModel(HubertPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1015,7 +1015,7 @@ class HubertModel(HubertPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/sew/modeling_sew.py
View file @
68be1d3c
...
...
@@ -834,7 +834,7 @@ class SEWModel(SEWPreTrainedModel):
self
.
feature_dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
SEWEncoder
(
config
)
...
...
@@ -862,7 +862,7 @@ class SEWModel(SEWPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -872,7 +872,7 @@ class SEWModel(SEWPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/sew_d/modeling_sew_d.py
View file @
68be1d3c
...
...
@@ -1360,7 +1360,7 @@ class SEWDModel(SEWDPreTrainedModel):
self
.
feature_dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
SEWDEncoder
(
config
)
...
...
@@ -1388,7 +1388,7 @@ class SEWDModel(SEWDPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1398,7 +1398,7 @@ class SEWDModel(SEWDPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/speecht5/modeling_speecht5.py
View file @
68be1d3c
...
...
@@ -517,7 +517,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
pos_conv_embed
=
SpeechT5PositionalConvEmbedding
(
config
)
self
.
pos_sinusoidal_embed
=
SpeechT5SinusoidalPositionalEmbedding
(
...
...
@@ -616,7 +616,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -626,7 +626,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/unispeech/modeling_unispeech.py
View file @
68be1d3c
...
...
@@ -1090,7 +1090,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
self
.
feature_projection
=
UniSpeechFeatureProjection
(
config
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
UniSpeechEncoderStableLayerNorm
(
config
)
...
...
@@ -1121,7 +1121,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1131,7 +1131,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
View file @
68be1d3c
...
...
@@ -1108,7 +1108,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
self
.
feature_extractor
=
UniSpeechSatFeatureEncoder
(
config
)
self
.
feature_projection
=
UniSpeechSatFeatureProjection
(
config
)
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
UniSpeechSatEncoderStableLayerNorm
(
config
)
...
...
@@ -1139,7 +1139,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1149,7 +1149,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
68be1d3c
...
...
@@ -1445,7 +1445,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
Wav2Vec2EncoderStableLayerNorm
(
config
)
...
...
@@ -1496,7 +1496,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1506,7 +1506,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
View file @
68be1d3c
...
...
@@ -1053,7 +1053,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
Wav2Vec2BertEncoder
(
config
)
...
...
@@ -1087,7 +1087,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1097,7 +1097,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
View file @
68be1d3c
...
...
@@ -1235,7 +1235,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
Wav2Vec2ConformerEncoder
(
config
)
...
...
@@ -1273,7 +1273,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1283,7 +1283,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wavlm/modeling_wavlm.py
View file @
68be1d3c
...
...
@@ -1107,7 +1107,7 @@ class WavLMModel(WavLMPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Float
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
WavLMEncoderStableLayerNorm
(
config
)
...
...
@@ -1158,7 +1158,7 @@ class WavLMModel(WavLMPreTrainedModel):
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
...
...
@@ -1168,7 +1168,7 @@ class WavLMModel(WavLMPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment