chenpangpang / transformers · Commits · 7630c11f

Unverified commit 7630c11f, authored May 25, 2021 by Patrick von Platen, committed by GitHub on May 25, 2021.
[Wav2Vec2] SpecAugment Fast (#11764)

* first try
* finish
parent f086652b
Changes: 2 changed files with 49 additions and 78 deletions.

src/transformers/models/wav2vec2/modeling_wav2vec2.py  +47 −51
tests/test_modeling_wav2vec2.py                        +2 −27
src/transformers/models/wav2vec2/modeling_wav2vec2.py (view file @ 7630c11f)
@@ -48,71 +48,67 @@ def _compute_mask_indices(
     shape: Tuple[int, int],
     mask_prob: float,
     mask_length: int,
-    attention_mask: Optional[torch.Tensor] = None,
+    device: torch.device,
     min_masks: int = 0,
-) -> np.ndarray:
+) -> torch.tensor:
     """
-    Computes random mask spans for a given shape
+    Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
+    ASR <https://arxiv.org/abs/1904.08779>`__.

     Args:
         shape: the the shape for which to compute masks.
             should be of size 2 where first element is batch size and 2nd is timesteps
-        attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
         mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
             number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
             however due to overlaps, the actual number will be smaller (unless no_overlap is True)
         mask_length: size of the mask
         min_masks: minimum number of masked spans
-
-    Adapted from `fairseq's data_utils.py
-    <https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
     """
-    bsz, all_sz = shape
-    mask = np.full((bsz, all_sz), False)
+    batch_size, sequence_length = shape

-    all_num_mask = int(
-        # add a random number for probabilistic rounding
-        mask_prob * all_sz / float(mask_length)
-        + np.random.rand()
-    )
-    all_num_mask = max(min_masks, all_num_mask)
+    if mask_length < 1:
+        raise ValueError("`mask_length` has to be bigger than 0.")

-    mask_idcs = []
-    padding_mask = attention_mask.ne(1) if attention_mask is not None else None
-    for i in range(bsz):
-        if padding_mask is not None:
-            sz = all_sz - padding_mask[i].long().sum().item()
-            num_mask = int(
-                # add a random number for probabilistic rounding
-                mask_prob * sz / float(mask_length)
-                + np.random.rand()
-            )
-            num_mask = max(min_masks, num_mask)
-        else:
-            sz = all_sz
-            num_mask = all_num_mask
+    if mask_length > sequence_length:
+        raise ValueError(
+            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+        )

-        lengths = np.full(num_mask, mask_length)
+    # compute number of masked spans in batch
+    num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
+    num_masked_spans = max(num_masked_spans, min_masks)

-        if sum(lengths) == 0:
-            lengths[0] = min(mask_length, sz - 1)
+    # make sure num masked indices <= sequence_length
+    if num_masked_spans * mask_length > sequence_length:
+        num_masked_spans = sequence_length // mask_length

-        min_len = min(lengths)
-        if sz - min_len <= num_mask:
-            min_len = sz - num_mask - 1
+    # SpecAugment mask to fill
+    spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)

-        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
-        mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
-        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+    # uniform distribution to sample from, make sure that offset samples are < sequence_length
+    uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)

-    min_len = min([len(m) for m in mask_idcs])
-    for i, mask_idc in enumerate(mask_idcs):
-        if len(mask_idc) > min_len:
-            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
-        mask[i, mask_idc] = True
+    # get random indices to mask
+    spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)

-    return mask
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = (
+        spec_aug_mask_idxs.unsqueeze(dim=-1)
+        .expand((batch_size, num_masked_spans, mask_length))
+        .reshape(batch_size, num_masked_spans * mask_length)
+    )
+    offsets = (
+        torch.arange(mask_length, device=device)[None, None, :]
+        .expand((batch_size, num_masked_spans, mask_length))
+        .reshape(batch_size, num_masked_spans * mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # scatter indices to mask
+    spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
+
+    return spec_aug_mask

 class Wav2Vec2NoLayerNormConvLayer(nn.Module):
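To try the rewritten sampler in isolation, here is a minimal sketch (not part of the commit): it assumes the private helper above is importable from its module, and the seed, shapes, and probabilities are arbitrary illustration values.

    import torch
    from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

    torch.manual_seed(0)  # only for reproducibility of the illustration

    # mask ~5% of 100 timesteps in spans of 10 for a batch of 2
    mask = _compute_mask_indices(
        shape=(2, 100),
        mask_prob=0.05,
        mask_length=10,
        device=torch.device("cpu"),
        min_masks=2,
    )

    print(mask.shape, mask.dtype)  # torch.Size([2, 100]) torch.bool
    print(mask.sum(dim=-1))        # masked steps per row; at most min_masks * mask_length = 20 here

Unlike the numpy version it replaces, the result is a torch.bool tensor already on the requested device, so it can index hidden states without a host round-trip.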
@@ -847,21 +843,21 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
         if self.config.mask_time_prob > 0:
             mask_time_indices = _compute_mask_indices(
                 (batch_size, sequence_length),
-                self.config.mask_time_prob,
-                self.config.mask_time_length,
-                attention_mask=attention_mask,
+                mask_prob=self.config.mask_time_prob,
+                mask_length=self.config.mask_time_length,
+                device=hidden_states.device,
                 min_masks=2,
             )
-            hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
+            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

         # apply SpecAugment along feature axis
         if self.config.mask_feature_prob > 0:
             mask_feature_indices = _compute_mask_indices(
                 (batch_size, hidden_size),
-                self.config.mask_feature_prob,
-                self.config.mask_feature_length,
+                mask_prob=self.config.mask_feature_prob,
+                mask_length=self.config.mask_feature_length,
+                device=hidden_states.device,
             )
-            mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
             hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0

         encoder_outputs = self.encoder(
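What this hunk changes in practice: the mask now arrives as torch.bool on hidden_states.device, so both torch.from_numpy conversions disappear. Below is a standalone sketch of the same indexing pattern (not part of the diff), with a random vector standing in for the model's learned masked_spec_embed parameter:

    import torch
    from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

    batch_size, sequence_length, hidden_size = 2, 50, 16
    hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
    masked_spec_embed = torch.randn(hidden_size)  # stand-in for the learned parameter

    mask_time_indices = _compute_mask_indices(
        (batch_size, sequence_length),
        mask_prob=0.05,
        mask_length=10,
        device=hidden_states.device,
        min_masks=2,
    )

    # boolean indexing over (batch, time) broadcasts the vector into every masked timestep
    hidden_states[mask_time_indices] = masked_spec_embed.to(hidden_states.dtype)

Note that the new code path no longer forwards attention_mask to _compute_mask_indices, so padded positions may now be selected for masking; the test changes below drop the corresponding assertions.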
tests/test_modeling_wav2vec2.py (view file @ 7630c11f)
@@ -478,26 +478,17 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         mask_prob = 0.5
         mask_length = 1

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

         self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

-        attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
-        attention_mask[:, -sequence_length // 2 :] = 0
-        mask = _compute_mask_indices(
-            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
-        )
-        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
-
     def test_compute_mask_indices_overlap(self):
         batch_size = 4
         sequence_length = 60
         mask_prob = 0.5
         mask_length = 4

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

         # because of overlap there is a range of possible masks
         for batch_sum in mask.sum(axis=-1):
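An aside on why the first assertion can demand exact equality (reasoning about the new implementation, not part of the diff): with mask_length == 1 each span is a single timestep and torch.multinomial draws start indices without replacement, so spans cannot overlap and every row masks exactly int(mask_prob * sequence_length + rand) positions, which equals mask_prob * sequence_length whenever that product is a whole number. With mask_length == 4, as in the overlap test, adjacent spans can cover the same timesteps, so only a range is asserted.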
@@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
...
@@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
list
(
range
(
int
(
mask_prob
//
mask_length
*
sequence_length
),
int
(
mask_prob
*
sequence_length
))),
list
(
range
(
int
(
mask_prob
//
mask_length
*
sequence_length
),
int
(
mask_prob
*
sequence_length
))),
)
)
attention_mask
=
torch
.
ones
((
batch_size
,
sequence_length
),
device
=
torch_device
,
dtype
=
torch
.
long
)
attention_mask
[:,
-
sequence_length
//
2
:]
=
0
mask
=
_compute_mask_indices
(
(
batch_size
,
sequence_length
),
mask_prob
,
mask_length
,
attention_mask
=
attention_mask
)
# because of overlap there is a range of possible masks
for
batch_sum
in
mask
.
sum
(
axis
=-
1
):
self
.
assertIn
(
int
(
batch_sum
),
list
(
range
(
int
(
mask_prob
//
mask_length
*
sequence_length
//
2
),
int
(
mask_prob
*
sequence_length
//
2
))
),
)
@
require_torch
@
require_torch
@
slow
@
slow
...
...
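To reproduce the surviving overlap check outside the test harness, a sketch (torch_device is normally resolved by the testing utilities; a plain CPU device is assumed here):

    import torch
    from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

    batch_size, sequence_length, mask_prob, mask_length = 4, 60, 0.5, 4

    mask = _compute_mask_indices(
        (batch_size, sequence_length), mask_prob, mask_length, torch.device("cpu")
    )

    # int(0.5 * 60 / 4 + rand) -> 7 or 8 spans of 4 timesteps per row;
    # overlap between spans pulls each row's sum below that 28- or 32-step ceiling
    print(mask.sum(dim=-1))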