Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
27d348f2
Unverified
Commit
27d348f2
authored
Jul 01, 2021
by
Patrick von Platen
Committed by
GitHub
Jul 01, 2021
Browse files
[Wav2Vec2, Hubert] Fix ctc loss test (#12458)
* fix_torch_device_generate_test * remove @ * fix test
parent
b655f16d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
tests/test_modeling_hubert.py
tests/test_modeling_hubert.py
+4
-3
tests/test_modeling_wav2vec2.py
tests/test_modeling_wav2vec2.py
+4
-3
No files found.
tests/test_modeling_hubert.py
View file @
27d348f2
...
@@ -176,12 +176,13 @@ class HubertModelTester:
...
@@ -176,12 +176,13 @@ class HubertModelTester:
attention_mask
[
i
,
input_lengths
[
i
]
:]
=
0
attention_mask
[
i
,
input_lengths
[
i
]
:]
=
0
model
.
config
.
ctc_loss_reduction
=
"sum"
model
.
config
.
ctc_loss_reduction
=
"sum"
sum_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
sum_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
.
item
()
model
.
config
.
ctc_loss_reduction
=
"mean"
model
.
config
.
ctc_loss_reduction
=
"mean"
mean_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
mean_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
.
item
()
self
.
parent
.
assertTrue
(
abs
(
labels
.
shape
[
0
]
*
labels
.
shape
[
1
]
*
mean_loss
.
item
()
-
sum_loss
.
item
())
<
1e-3
)
self
.
parent
.
assertTrue
(
isinstance
(
sum_loss
,
float
))
self
.
parent
.
assertTrue
(
isinstance
(
mean_loss
,
float
))
def
check_training
(
self
,
config
,
input_values
,
*
args
):
def
check_training
(
self
,
config
,
input_values
,
*
args
):
config
.
ctc_zero_infinity
=
True
config
.
ctc_zero_infinity
=
True
...
...
tests/test_modeling_wav2vec2.py
View file @
27d348f2
...
@@ -184,12 +184,13 @@ class Wav2Vec2ModelTester:
...
@@ -184,12 +184,13 @@ class Wav2Vec2ModelTester:
attention_mask
[
i
,
input_lengths
[
i
]
:]
=
0
attention_mask
[
i
,
input_lengths
[
i
]
:]
=
0
model
.
config
.
ctc_loss_reduction
=
"sum"
model
.
config
.
ctc_loss_reduction
=
"sum"
sum_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
sum_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
.
item
()
model
.
config
.
ctc_loss_reduction
=
"mean"
model
.
config
.
ctc_loss_reduction
=
"mean"
mean_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
mean_loss
=
model
(
input_values
,
attention_mask
=
attention_mask
,
labels
=
labels
).
loss
.
item
()
self
.
parent
.
assertTrue
(
abs
(
labels
.
shape
[
0
]
*
labels
.
shape
[
1
]
*
mean_loss
.
item
()
-
sum_loss
.
item
())
<
1e-3
)
self
.
parent
.
assertTrue
(
isinstance
(
sum_loss
,
float
))
self
.
parent
.
assertTrue
(
isinstance
(
mean_loss
,
float
))
def
check_training
(
self
,
config
,
input_values
,
*
args
):
def
check_training
(
self
,
config
,
input_values
,
*
args
):
config
.
ctc_zero_infinity
=
True
config
.
ctc_zero_infinity
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment