Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
68be1d3c
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "abca1741cf065749c44ef5d77f6f632c50beb070"
Unverified
Commit
68be1d3c
authored
Apr 18, 2024
by
Yoach Lacombe
Committed by
GitHub
Apr 18, 2024
Browse files
fix Parameter dtype in audio models (#30310)
parent
79132145
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
33 additions
and
33 deletions
+33
-33
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
+3
-3
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/hubert/modeling_hubert.py
+3
-3
src/transformers/models/sew/modeling_sew.py
src/transformers/models/sew/modeling_sew.py
+3
-3
src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/sew_d/modeling_sew_d.py
+3
-3
src/transformers/models/speecht5/modeling_speecht5.py
src/transformers/models/speecht5/modeling_speecht5.py
+3
-3
src/transformers/models/unispeech/modeling_unispeech.py
src/transformers/models/unispeech/modeling_unispeech.py
+3
-3
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+3
-3
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+3
-3
src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
+3
-3
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+3
-3
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/wavlm/modeling_wavlm.py
+3
-3
No files found.
src/transformers/models/data2vec/modeling_data2vec_audio.py
View file @
68be1d3c
...
@@ -822,7 +822,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
...
@@ -822,7 +822,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
Data2VecAudioEncoder
(
config
)
self
.
encoder
=
Data2VecAudioEncoder
(
config
)
...
@@ -858,7 +858,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
...
@@ -858,7 +858,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -868,7 +868,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
...
@@ -868,7 +868,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/hubert/modeling_hubert.py
View file @
68be1d3c
...
@@ -974,7 +974,7 @@ class HubertModel(HubertPreTrainedModel):
...
@@ -974,7 +974,7 @@ class HubertModel(HubertPreTrainedModel):
self
.
feature_projection
=
HubertFeatureProjection
(
config
)
self
.
feature_projection
=
HubertFeatureProjection
(
config
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
HubertEncoderStableLayerNorm
(
config
)
self
.
encoder
=
HubertEncoderStableLayerNorm
(
config
)
...
@@ -1005,7 +1005,7 @@ class HubertModel(HubertPreTrainedModel):
...
@@ -1005,7 +1005,7 @@ class HubertModel(HubertPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1015,7 +1015,7 @@ class HubertModel(HubertPreTrainedModel):
...
@@ -1015,7 +1015,7 @@ class HubertModel(HubertPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/sew/modeling_sew.py
View file @
68be1d3c
...
@@ -834,7 +834,7 @@ class SEWModel(SEWPreTrainedModel):
...
@@ -834,7 +834,7 @@ class SEWModel(SEWPreTrainedModel):
self
.
feature_dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
self
.
feature_dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
SEWEncoder
(
config
)
self
.
encoder
=
SEWEncoder
(
config
)
...
@@ -862,7 +862,7 @@ class SEWModel(SEWPreTrainedModel):
...
@@ -862,7 +862,7 @@ class SEWModel(SEWPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -872,7 +872,7 @@ class SEWModel(SEWPreTrainedModel):
...
@@ -872,7 +872,7 @@ class SEWModel(SEWPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/sew_d/modeling_sew_d.py
View file @
68be1d3c
...
@@ -1360,7 +1360,7 @@ class SEWDModel(SEWDPreTrainedModel):
...
@@ -1360,7 +1360,7 @@ class SEWDModel(SEWDPreTrainedModel):
self
.
feature_dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
self
.
feature_dropout
=
nn
.
Dropout
(
config
.
feat_proj_dropout
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
SEWDEncoder
(
config
)
self
.
encoder
=
SEWDEncoder
(
config
)
...
@@ -1388,7 +1388,7 @@ class SEWDModel(SEWDPreTrainedModel):
...
@@ -1388,7 +1388,7 @@ class SEWDModel(SEWDPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1398,7 +1398,7 @@ class SEWDModel(SEWDPreTrainedModel):
...
@@ -1398,7 +1398,7 @@ class SEWDModel(SEWDPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/speecht5/modeling_speecht5.py
View file @
68be1d3c
...
@@ -517,7 +517,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
...
@@ -517,7 +517,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
# model only needs masking vector if mask prob is > 0.0
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
pos_conv_embed
=
SpeechT5PositionalConvEmbedding
(
config
)
self
.
pos_conv_embed
=
SpeechT5PositionalConvEmbedding
(
config
)
self
.
pos_sinusoidal_embed
=
SpeechT5SinusoidalPositionalEmbedding
(
self
.
pos_sinusoidal_embed
=
SpeechT5SinusoidalPositionalEmbedding
(
...
@@ -616,7 +616,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
...
@@ -616,7 +616,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -626,7 +626,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
...
@@ -626,7 +626,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/unispeech/modeling_unispeech.py
View file @
68be1d3c
...
@@ -1090,7 +1090,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
...
@@ -1090,7 +1090,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
self
.
feature_projection
=
UniSpeechFeatureProjection
(
config
)
self
.
feature_projection
=
UniSpeechFeatureProjection
(
config
)
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
UniSpeechEncoderStableLayerNorm
(
config
)
self
.
encoder
=
UniSpeechEncoderStableLayerNorm
(
config
)
...
@@ -1121,7 +1121,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
...
@@ -1121,7 +1121,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1131,7 +1131,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
...
@@ -1131,7 +1131,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
View file @
68be1d3c
...
@@ -1108,7 +1108,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
...
@@ -1108,7 +1108,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
self
.
feature_extractor
=
UniSpeechSatFeatureEncoder
(
config
)
self
.
feature_extractor
=
UniSpeechSatFeatureEncoder
(
config
)
self
.
feature_projection
=
UniSpeechSatFeatureProjection
(
config
)
self
.
feature_projection
=
UniSpeechSatFeatureProjection
(
config
)
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
UniSpeechSatEncoderStableLayerNorm
(
config
)
self
.
encoder
=
UniSpeechSatEncoderStableLayerNorm
(
config
)
...
@@ -1139,7 +1139,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
...
@@ -1139,7 +1139,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1149,7 +1149,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
...
@@ -1149,7 +1149,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
68be1d3c
...
@@ -1445,7 +1445,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
...
@@ -1445,7 +1445,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
Wav2Vec2EncoderStableLayerNorm
(
config
)
self
.
encoder
=
Wav2Vec2EncoderStableLayerNorm
(
config
)
...
@@ -1496,7 +1496,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
...
@@ -1496,7 +1496,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1506,7 +1506,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
...
@@ -1506,7 +1506,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
View file @
68be1d3c
...
@@ -1053,7 +1053,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
...
@@ -1053,7 +1053,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
Wav2Vec2BertEncoder
(
config
)
self
.
encoder
=
Wav2Vec2BertEncoder
(
config
)
...
@@ -1087,7 +1087,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
...
@@ -1087,7 +1087,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1097,7 +1097,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
...
@@ -1097,7 +1097,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
View file @
68be1d3c
...
@@ -1235,7 +1235,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
...
@@ -1235,7 +1235,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
self
.
encoder
=
Wav2Vec2ConformerEncoder
(
config
)
self
.
encoder
=
Wav2Vec2ConformerEncoder
(
config
)
...
@@ -1273,7 +1273,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
...
@@ -1273,7 +1273,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1283,7 +1283,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
...
@@ -1283,7 +1283,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
src/transformers/models/wavlm/modeling_wavlm.py
View file @
68be1d3c
...
@@ -1107,7 +1107,7 @@ class WavLMModel(WavLMPreTrainedModel):
...
@@ -1107,7 +1107,7 @@ class WavLMModel(WavLMPreTrainedModel):
# model only needs masking vector if mask prob is > 0.0
# model only needs masking vector if mask prob is > 0.0
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
if
config
.
mask_time_prob
>
0.0
or
config
.
mask_feature_prob
>
0.0
:
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
FloatTensor
(
config
.
hidden_size
).
uniform_
())
self
.
masked_spec_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
).
uniform_
())
if
config
.
do_stable_layer_norm
:
if
config
.
do_stable_layer_norm
:
self
.
encoder
=
WavLMEncoderStableLayerNorm
(
config
)
self
.
encoder
=
WavLMEncoderStableLayerNorm
(
config
)
...
@@ -1158,7 +1158,7 @@ class WavLMModel(WavLMPreTrainedModel):
...
@@ -1158,7 +1158,7 @@ class WavLMModel(WavLMPreTrainedModel):
if
mask_time_indices
is
not
None
:
if
mask_time_indices
is
not
None
:
# apply SpecAugment along time axis with given mask_time_indices
# apply SpecAugment along time axis with given mask_time_indices
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
elif
self
.
config
.
mask_time_prob
>
0
and
self
.
training
:
mask_time_indices
=
_compute_mask_indices
(
mask_time_indices
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
(
batch_size
,
sequence_length
),
...
@@ -1168,7 +1168,7 @@ class WavLMModel(WavLMPreTrainedModel):
...
@@ -1168,7 +1168,7 @@ class WavLMModel(WavLMPreTrainedModel):
min_masks
=
self
.
config
.
mask_time_min_masks
,
min_masks
=
self
.
config
.
mask_time_min_masks
,
)
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
# generate indices & apply SpecAugment along feature axis
# generate indices & apply SpecAugment along feature axis
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment