chenpangpang / transformers

Unverified commit 56200331, authored Jun 11, 2020 by Sam Shleifer, committed by GitHub on Jun 11, 2020.
Parent: 473808da

[mbart] Fix fp16 testing logic (#4949)

Showing 1 changed file with 1 addition and 4 deletions.

tests/test_modeling_bart.py: +1, -4
```diff
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -249,7 +249,7 @@ class MBartIntegrationTests(unittest.TestCase):
         with torch.no_grad():
             logits, *other_stuff = model(**net_input)

-        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
+        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device, dtype=model.dtype)
         result_slice = logits[0][0][:3]
         self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
@@ -293,9 +293,6 @@ class MBartIntegrationTests(unittest.TestCase):
         expected_shape = (*summary.shape, config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)

-
-@require_torch
-class MBartTokenizerTests(MBartIntegrationTests):
     def test_enro_tokenizer_prepare_translation_batch(self):
         batch = self.tokenizer.prepare_translation_batch(
             self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
```
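The changed line is the core of the fix: `torch.tensor` defaults to `torch.float32`, so the hard-coded reference slice used to be fp32 even when the integration test ran the model in half precision, and the comparison then mixed fp16 logits with an fp32 reference (which, depending on the PyTorch version, either errors out in `torch.allclose` or compares across precisions). Creating the reference with `dtype=model.dtype` keeps the whole check in the precision the model actually computed in. Below is a minimal sketch of that pattern, not code from the commit: the `check_logits_slice` helper and the toy `torch.nn.Linear` standing in for MBart are hypothetical, and the `TOLERANCE` value here is chosen for the demo rather than taken from the test module.

```python
# Sketch of the dtype-matching pattern applied by the fix (hypothetical helper,
# not part of the transformers test suite).
import torch

# Loose tolerance for this fp16 toy demo; the real test defines its own TOLERANCE.
TOLERANCE = 1e-2


def check_logits_slice(model: torch.nn.Module, logits: torch.Tensor, expected_values) -> bool:
    # The diff uses `model.dtype`, a convenience property on transformers models;
    # for a plain nn.Module we read the dtype from its parameters instead.
    model_dtype = next(model.parameters()).dtype
    # Build the reference in the dtype (and on the device) the model actually
    # computed in, so an fp16 run never compares half against float.
    expected = torch.tensor(expected_values, device=logits.device, dtype=model_dtype)
    result = logits[0][0][:3]
    return torch.allclose(expected, result, atol=TOLERANCE)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 3)          # toy stand-in for the MBart model
    inputs = torch.randn(1, 5, 4)
    with torch.no_grad():
        logits = model(inputs)             # shape (1, 5, 3)
    reference = logits[0][0][:3].tolist()  # pretend these are the hard-coded values

    # fp32 run: reference and logits are both float32.
    print(check_logits_slice(model, logits, reference))

    # fp16 run: the same reference values are re-materialized as float16, so the
    # comparison stays within one precision. Guarded on CUDA because CPU half
    # matmul support varies across PyTorch versions.
    if torch.cuda.is_available():
        model_fp16 = model.half().cuda()
        with torch.no_grad():
            logits_fp16 = model_fp16(inputs.half().cuda())
        print(check_logits_slice(model_fp16, logits_fp16, reference))
```

One consequence of matching the reference dtype is that the fp16 run is compared at half-precision resolution, so the test's tolerance has to absorb fp16 rounding error rather than expecting fp32-level agreement.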