Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0f5488f7
Unverified
Commit
0f5488f7
authored
Oct 07, 2021
by
Patrick von Platen
Committed by
GitHub
Oct 07, 2021
Browse files
[Wav2Vec2] Fix mask_feature_prob (#13921)
* up * overwrite hubert
parent
57420b10
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
93 additions
and
2 deletions
+93
-2
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/hubert/modeling_hubert.py
+0
-1
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+0
-1
tests/test_modeling_wav2vec2.py
tests/test_modeling_wav2vec2.py
+93
-0
No files found.
src/transformers/models/hubert/modeling_hubert.py
View file @
0f5488f7
...
@@ -877,7 +877,6 @@ class HubertModel(HubertPreTrainedModel):
...
@@ -877,7 +877,6 @@ class HubertModel(HubertPreTrainedModel):
mask_prob
=
self
.
config
.
mask_feature_prob
,
mask_prob
=
self
.
config
.
mask_feature_prob
,
mask_length
=
self
.
config
.
mask_feature_length
,
mask_length
=
self
.
config
.
mask_feature_length
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
attention_mask
=
attention_mask
,
)
)
hidden_states
[
mask_feature_indices
[:,
None
].
expand
(
-
1
,
sequence_length
,
-
1
)]
=
0
hidden_states
[
mask_feature_indices
[:,
None
].
expand
(
-
1
,
sequence_length
,
-
1
)]
=
0
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
0f5488f7
...
@@ -1014,7 +1014,6 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
...
@@ -1014,7 +1014,6 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
mask_prob
=
self
.
config
.
mask_feature_prob
,
mask_prob
=
self
.
config
.
mask_feature_prob
,
mask_length
=
self
.
config
.
mask_feature_length
,
mask_length
=
self
.
config
.
mask_feature_length
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
attention_mask
=
attention_mask
,
)
)
hidden_states
[
mask_feature_indices
[:,
None
].
expand
(
-
1
,
sequence_length
,
-
1
)]
=
0
hidden_states
[
mask_feature_indices
[:,
None
].
expand
(
-
1
,
sequence_length
,
-
1
)]
=
0
...
...
tests/test_modeling_wav2vec2.py
View file @
0f5488f7
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
math
import
math
import
unittest
import
unittest
import
numpy
as
np
import
pytest
import
pytest
from
tests.test_modeling_common
import
floats_tensor
,
ids_tensor
,
random_attention_mask
from
tests.test_modeling_common
import
floats_tensor
,
ids_tensor
,
random_attention_mask
...
@@ -433,6 +434,52 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -433,6 +434,52 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
if
hasattr
(
module
,
"masked_spec_embed"
)
and
module
.
masked_spec_embed
is
not
None
:
if
hasattr
(
module
,
"masked_spec_embed"
)
and
module
.
masked_spec_embed
is
not
None
:
module
.
masked_spec_embed
.
data
.
fill_
(
3
)
module
.
masked_spec_embed
.
data
.
fill_
(
3
)
def
test_mask_feature_prob_ctc
(
self
):
model
=
Wav2Vec2ForCTC
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
mask_feature_prob
=
0.2
,
mask_feature_length
=
2
)
model
.
to
(
torch_device
).
train
()
processor
=
Wav2Vec2Processor
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
return_attention_mask
=
True
)
batch_duration_in_seconds
=
[
1
,
3
,
2
,
6
]
input_features
=
[
np
.
random
.
random
(
16_000
*
s
)
for
s
in
batch_duration_in_seconds
]
batch
=
processor
(
input_features
,
padding
=
True
,
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
,
return_tensors
=
"pt"
)
logits
=
model
(
input_values
=
batch
[
"input_values"
].
to
(
torch_device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
torch_device
),
).
logits
self
.
assertEqual
(
logits
.
shape
,
(
4
,
1498
,
32
))
def
test_mask_time_prob_ctc
(
self
):
model
=
Wav2Vec2ForCTC
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
mask_time_prob
=
0.2
,
mask_time_length
=
2
)
model
.
to
(
torch_device
).
train
()
processor
=
Wav2Vec2Processor
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
return_attention_mask
=
True
)
batch_duration_in_seconds
=
[
1
,
3
,
2
,
6
]
input_features
=
[
np
.
random
.
random
(
16_000
*
s
)
for
s
in
batch_duration_in_seconds
]
batch
=
processor
(
input_features
,
padding
=
True
,
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
,
return_tensors
=
"pt"
)
logits
=
model
(
input_values
=
batch
[
"input_values"
].
to
(
torch_device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
torch_device
),
).
logits
self
.
assertEqual
(
logits
.
shape
,
(
4
,
1498
,
32
))
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
model
=
Wav2Vec2Model
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
model
=
Wav2Vec2Model
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
...
@@ -620,6 +667,52 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -620,6 +667,52 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
# loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
self
.
assertTrue
(
loss
.
detach
().
item
()
<=
loss_more_masked
.
detach
().
item
())
self
.
assertTrue
(
loss
.
detach
().
item
()
<=
loss_more_masked
.
detach
().
item
())
def
test_mask_feature_prob_ctc
(
self
):
model
=
Wav2Vec2ForCTC
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
mask_feature_prob
=
0.2
,
mask_feature_length
=
2
)
model
.
to
(
torch_device
).
train
()
processor
=
Wav2Vec2Processor
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
return_attention_mask
=
True
)
batch_duration_in_seconds
=
[
1
,
3
,
2
,
6
]
input_features
=
[
np
.
random
.
random
(
16_000
*
s
)
for
s
in
batch_duration_in_seconds
]
batch
=
processor
(
input_features
,
padding
=
True
,
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
,
return_tensors
=
"pt"
)
logits
=
model
(
input_values
=
batch
[
"input_values"
].
to
(
torch_device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
torch_device
),
).
logits
self
.
assertEqual
(
logits
.
shape
,
(
4
,
1498
,
32
))
def
test_mask_time_prob_ctc
(
self
):
model
=
Wav2Vec2ForCTC
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
mask_time_prob
=
0.2
,
mask_time_length
=
2
)
model
.
to
(
torch_device
).
train
()
processor
=
Wav2Vec2Processor
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
return_attention_mask
=
True
)
batch_duration_in_seconds
=
[
1
,
3
,
2
,
6
]
input_features
=
[
np
.
random
.
random
(
16_000
*
s
)
for
s
in
batch_duration_in_seconds
]
batch
=
processor
(
input_features
,
padding
=
True
,
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
,
return_tensors
=
"pt"
)
logits
=
model
(
input_values
=
batch
[
"input_values"
].
to
(
torch_device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
torch_device
),
).
logits
self
.
assertEqual
(
logits
.
shape
,
(
4
,
1498
,
32
))
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
model
=
Wav2Vec2Model
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
model
=
Wav2Vec2Model
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment