chenpangpang / transformers · Commits

Commit 0f5488f7 (unverified)
Authored Oct 07, 2021 by Patrick von Platen; committed by GitHub on Oct 07, 2021

[Wav2Vec2] Fix mask_feature_prob (#13921)

* up
* overwrite hubert
Parent: 57420b10

Showing 3 changed files with 93 additions and 2 deletions (+93 / -2)
src/transformers/models/hubert/modeling_hubert.py (+0 / -1)
src/transformers/models/wav2vec2/modeling_wav2vec2.py (+0 / -1)
tests/test_modeling_wav2vec2.py (+93 / -0)
src/transformers/models/hubert/modeling_hubert.py

```diff
@@ -877,7 +877,6 @@ class HubertModel(HubertPreTrainedModel):
                 mask_prob=self.config.mask_feature_prob,
                 mask_length=self.config.mask_feature_length,
                 device=hidden_states.device,
-                attention_mask=attention_mask,
             )
             hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
 
```
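Why the deleted argument was wrong is easiest to see with shapes in hand: for feature-axis masking, `_compute_mask_indices` is called on `(batch_size, hidden_size)`, while `attention_mask` describes padding along the time axis with shape `(batch_size, sequence_length)`, so it has no meaningful role in this call. Below is a minimal sketch of feature-axis SpecAugment under that shape convention; `compute_mask_indices` here is a simplified stand-in for the library helper, and its span sampling is illustrative rather than the real algorithm:

```python
import torch

def compute_mask_indices(shape, mask_prob, mask_length):
    """Simplified stand-in: sample boolean spans of `mask_length` over `shape`."""
    batch_size, num_features = shape
    num_spans = int(mask_prob * num_features / mask_length)
    mask = torch.zeros(batch_size, num_features, dtype=torch.bool)
    for b in range(batch_size):
        starts = torch.randperm(num_features - mask_length)[:num_spans]
        for s in starts.tolist():
            mask[b, s : s + mask_length] = True
    return mask

batch_size, sequence_length, hidden_size = 2, 50, 32
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# Feature masking samples over (batch_size, hidden_size): a masked feature
# channel is zeroed at every time step, so a (batch_size, sequence_length)
# attention mask cannot be applied to this call.
mask_feature_indices = compute_mask_indices((batch_size, hidden_size), mask_prob=0.2, mask_length=2)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
print(hidden_states[0].eq(0).all(dim=0).sum())  # number of fully zeroed channels
```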
src/transformers/models/wav2vec2/modeling_wav2vec2.py

```diff
@@ -1014,7 +1014,6 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
                 mask_prob=self.config.mask_feature_prob,
                 mask_length=self.config.mask_feature_length,
                 device=hidden_states.device,
-                attention_mask=attention_mask,
             )
             hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
 
```
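By contrast, the time-axis masking branch earlier in `_mask_hidden_states` keeps its `attention_mask` argument, because there the mask is sampled over `(batch_size, sequence_length)` and padded frames must not be selected. A rough sketch of that distinction, again with a simplified sampler rather than the real `_compute_mask_indices`; note the real models write a learned `masked_spec_embed` vector instead of zeros:

```python
import torch

def compute_time_mask(shape, mask_prob, mask_length, attention_mask=None):
    """Simplified stand-in: sample time-axis spans only inside unpadded frames."""
    batch_size, sequence_length = shape
    mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
    for b in range(batch_size):
        valid = int(attention_mask[b].sum()) if attention_mask is not None else sequence_length
        num_spans = int(mask_prob * valid / mask_length)
        starts = torch.randperm(max(valid - mask_length, 1))[:num_spans]
        for s in starts.tolist():
            mask[b, s : s + mask_length] = True
    return mask

batch_size, sequence_length, hidden_size = 2, 50, 32
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
attention_mask[1, 30:] = 0  # second example is padding beyond frame 30

mask_time_indices = compute_time_mask(
    (batch_size, sequence_length), mask_prob=0.2, mask_length=2, attention_mask=attention_mask
)
assert not mask_time_indices[1, 30:].any()  # padded frames are never masked
hidden_states[mask_time_indices] = 0.0  # the real models use self.masked_spec_embed here
```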
tests/test_modeling_wav2vec2.py

```diff
@@ -17,6 +17,7 @@
 import math
 import unittest
 
+import numpy as np
 import pytest
 
 from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
```
```diff
@@ -433,6 +434,52 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
             if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
                 module.masked_spec_embed.data.fill_(3)
 
+    def test_mask_feature_prob_ctc(self):
+        model = Wav2Vec2ForCTC.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", mask_feature_prob=0.2, mask_feature_length=2
+        )
+        model.to(torch_device).train()
+        processor = Wav2Vec2Processor.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
+        )
+
+        batch_duration_in_seconds = [1, 3, 2, 6]
+        input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+        batch = processor(
+            input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+
+        logits = model(
+            input_values=batch["input_values"].to(torch_device),
+            attention_mask=batch["attention_mask"].to(torch_device),
+        ).logits
+
+        self.assertEqual(logits.shape, (4, 1498, 32))
+
+    def test_mask_time_prob_ctc(self):
+        model = Wav2Vec2ForCTC.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", mask_time_prob=0.2, mask_time_length=2
+        )
+        model.to(torch_device).train()
+        processor = Wav2Vec2Processor.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
+        )
+
+        batch_duration_in_seconds = [1, 3, 2, 6]
+        input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+        batch = processor(
+            input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+
+        logits = model(
+            input_values=batch["input_values"].to(torch_device),
+            attention_mask=batch["attention_mask"].to(torch_device),
+        ).logits
+
+        self.assertEqual(logits.shape, (4, 1498, 32))
+
     @slow
     def test_model_from_pretrained(self):
         model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
```
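The asserted shape `(4, 1498, 32)` follows from the inputs: `padding=True` pads the four clips to the longest one (6 s at 16 kHz, i.e. 96,000 samples), the convolutional feature extractor downsamples that to 1498 frames, and 32 is the tiny checkpoint's CTC vocabulary size. The frame count comes from the usual conv output-length recurrence (as in `Wav2Vec2PreTrainedModel._get_feat_extract_output_lengths`); the three-layer kernel-8/stride-4 stack below is an assumption about `hf-internal-testing/tiny-random-wav2vec2`, chosen because it reproduces 1498 exactly:

```python
def feat_extract_output_length(input_length, kernels=(8, 8, 8), strides=(4, 4, 4)):
    # Same recurrence the library uses: each 1D conv layer maps
    # L -> floor((L - kernel) / stride) + 1.
    # The (8, 8, 8) / (4, 4, 4) conv stack is an assumed config for the
    # tiny test checkpoint, not a verified one.
    for kernel, stride in zip(kernels, strides):
        input_length = (input_length - kernel) // stride + 1
    return input_length

# Longest clip in the test batch: 6 seconds at 16 kHz.
print(feat_extract_output_length(16_000 * 6))  # 1498, the middle dim of the assert
```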
```diff
@@ -620,6 +667,52 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
         # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
         self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item())
 
+    def test_mask_feature_prob_ctc(self):
+        model = Wav2Vec2ForCTC.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", mask_feature_prob=0.2, mask_feature_length=2
+        )
+        model.to(torch_device).train()
+        processor = Wav2Vec2Processor.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
+        )
+
+        batch_duration_in_seconds = [1, 3, 2, 6]
+        input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+        batch = processor(
+            input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+
+        logits = model(
+            input_values=batch["input_values"].to(torch_device),
+            attention_mask=batch["attention_mask"].to(torch_device),
+        ).logits
+
+        self.assertEqual(logits.shape, (4, 1498, 32))
+
+    def test_mask_time_prob_ctc(self):
+        model = Wav2Vec2ForCTC.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", mask_time_prob=0.2, mask_time_length=2
+        )
+        model.to(torch_device).train()
+        processor = Wav2Vec2Processor.from_pretrained(
+            "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
+        )
+
+        batch_duration_in_seconds = [1, 3, 2, 6]
+        input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+        batch = processor(
+            input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+
+        logits = model(
+            input_values=batch["input_values"].to(torch_device),
+            attention_mask=batch["attention_mask"].to(torch_device),
+        ).logits
+
+        self.assertEqual(logits.shape, (4, 1498, 32))
+
     @slow
     def test_model_from_pretrained(self):
         model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
```