Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
58bf8825
Unverified
Commit
58bf8825
authored
Oct 12, 2021
by
Patrick von Platen
Committed by
GitHub
Oct 12, 2021
Browse files
[Wav2Vec2] Make sure tensors are always bool for mask_indices (#13977)
* correct long to bool * up * correct code
parent
11c043d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
4 deletions
+31
-4
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/hubert/modeling_hubert.py
+2
-2
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+2
-2
tests/test_modeling_wav2vec2.py
tests/test_modeling_wav2vec2.py
+27
-0
No files found.
src/transformers/models/hubert/modeling_hubert.py
View file @
58bf8825
...
...
@@ -907,7 +907,7 @@ class HubertModel(HubertPreTrainedModel):
attention_mask
=
attention_mask
,
min_masks
=
2
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
long
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
...
...
@@ -917,7 +917,7 @@ class HubertModel(HubertPreTrainedModel):
mask_prob
=
self
.
config
.
mask_feature_prob
,
mask_length
=
self
.
config
.
mask_feature_length
,
)
mask_feature_indices
=
torch
.
tensor
(
mask_feature_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
long
)[
mask_feature_indices
=
torch
.
tensor
(
mask_feature_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)[
:,
None
].
expand
(
-
1
,
sequence_length
,
-
1
)
hidden_states
[
mask_feature_indices
]
=
0
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
58bf8825
...
...
@@ -1100,7 +1100,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
attention_mask
=
attention_mask
,
min_masks
=
2
,
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
long
)
mask_time_indices
=
torch
.
tensor
(
mask_time_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
hidden_states
[
mask_time_indices
]
=
self
.
masked_spec_embed
.
to
(
hidden_states
.
dtype
)
if
self
.
config
.
mask_feature_prob
>
0
and
self
.
training
:
...
...
@@ -1110,7 +1110,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
mask_prob
=
self
.
config
.
mask_feature_prob
,
mask_length
=
self
.
config
.
mask_feature_length
,
)
mask_feature_indices
=
torch
.
tensor
(
mask_feature_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
long
)[
mask_feature_indices
=
torch
.
tensor
(
mask_feature_indices
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)[
:,
None
].
expand
(
-
1
,
sequence_length
,
-
1
)
hidden_states
[
mask_feature_indices
]
=
0
...
...
tests/test_modeling_wav2vec2.py
View file @
58bf8825
...
...
@@ -738,6 +738,33 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
self
.
assertEqual
(
logits
.
shape
,
(
4
,
1498
,
32
))
def
test_mask_time_feature_prob_ctc_single_batch
(
self
):
model
=
Wav2Vec2ForCTC
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
mask_time_prob
=
0.2
,
mask_feature_prob
=
0.2
,
mask_time_length
=
2
,
mask_feature_length
=
2
,
)
model
.
to
(
torch_device
).
train
()
processor
=
Wav2Vec2Processor
.
from_pretrained
(
"hf-internal-testing/tiny-random-wav2vec2"
,
return_attention_mask
=
True
)
batch_duration_in_seconds
=
[
6
]
input_features
=
[
np
.
random
.
random
(
16_000
*
s
)
for
s
in
batch_duration_in_seconds
]
batch
=
processor
(
input_features
,
padding
=
True
,
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
,
return_tensors
=
"pt"
)
logits
=
model
(
input_values
=
batch
[
"input_values"
].
to
(
torch_device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
torch_device
),
).
logits
self
.
assertEqual
(
logits
.
shape
,
(
1
,
1498
,
32
))
@
slow
def
test_model_from_pretrained
(
self
):
model
=
Wav2Vec2Model
.
from_pretrained
(
"facebook/wav2vec2-base-960h"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment