Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8bbe8247
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "fcb991db6347a4c35e77b9fe8a15f10edf92bb7e"
Unverified
Commit
8bbe8247
authored
Oct 26, 2020
by
Sam Shleifer
Committed by
GitHub
Oct 26, 2020
Browse files
Cleanup pytorch tests (#8033)
parent
20a0894d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
31 deletions
+3
-31
tests/test_modeling_marian.py
tests/test_modeling_marian.py
+0
-1
tests/test_modeling_mbart.py
tests/test_modeling_mbart.py
+1
-28
tests/test_modeling_pegasus.py
tests/test_modeling_pegasus.py
+2
-2
No files found.
tests/test_modeling_marian.py
View file @
8bbe8247
...
...
@@ -37,7 +37,6 @@ if is_torch_available():
from
transformers.pipelines
import
TranslationPipeline
@
require_torch
class
ModelTester
:
def
__init__
(
self
,
parent
):
self
.
config
=
MarianConfig
(
...
...
tests/test_modeling_mbart.py
View file @
8bbe8247
...
...
@@ -4,7 +4,6 @@ from transformers import is_torch_available
from
transformers.file_utils
import
cached_property
from
transformers.testing_utils
import
require_sentencepiece
,
require_tokenizers
,
require_torch
,
slow
,
torch_device
from
.test_modeling_bart
import
TOLERANCE
,
_long_tensor
,
assert_tensors_close
from
.test_modeling_common
import
ModelTesterMixin
...
...
@@ -91,32 +90,6 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
]
expected_src_tokens
=
[
8274
,
127873
,
25916
,
7
,
8622
,
2071
,
438
,
67485
,
53
,
187895
,
23
,
51712
,
2
,
EN_CODE
]
@
slow
@
unittest
.
skip
(
"This has been failing since June 20th at least."
)
def
test_enro_forward
(
self
):
model
=
self
.
model
net_input
=
{
"input_ids"
:
_long_tensor
(
[
[
3493
,
3060
,
621
,
104064
,
1810
,
100
,
142
,
566
,
13158
,
6889
,
5
,
2
,
250004
],
[
64511
,
7
,
765
,
2837
,
45188
,
297
,
4049
,
237
,
10
,
122122
,
5
,
2
,
250004
],
]
),
"decoder_input_ids"
:
_long_tensor
(
[
[
250020
,
31952
,
144
,
9019
,
242307
,
21980
,
55749
,
11
,
5
,
2
,
1
,
1
],
[
250020
,
884
,
9019
,
96
,
9
,
916
,
86792
,
36
,
18743
,
15596
,
5
,
2
],
]
),
}
net_input
[
"attention_mask"
]
=
net_input
[
"input_ids"
].
ne
(
1
)
with
torch
.
no_grad
():
logits
,
*
other_stuff
=
model
(
**
net_input
)
expected_slice
=
torch
.
tensor
([
9.0078
,
10.1113
,
14.4787
],
device
=
logits
.
device
,
dtype
=
logits
.
dtype
)
result_slice
=
logits
[
0
,
0
,
:
3
]
assert_tensors_close
(
expected_slice
,
result_slice
,
atol
=
TOLERANCE
)
@
slow
def
test_enro_generate_one
(
self
):
batch
:
BatchEncoding
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
...
...
@@ -128,7 +101,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
# self.assertEqual(self.tgt_text[1], decoded[1])
@
slow
def
test_enro_generate
(
self
):
def
test_enro_generate
_batch
(
self
):
batch
:
BatchEncoding
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
self
.
src_text
).
to
(
torch_device
)
translated_tokens
=
self
.
model
.
generate
(
**
batch
)
decoded
=
self
.
tokenizer
.
batch_decode
(
translated_tokens
,
skip_special_tokens
=
True
)
...
...
tests/test_modeling_pegasus.py
View file @
8bbe8247
...
...
@@ -58,7 +58,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
src_text
=
[
PGE_ARTICLE
,
XSUM_ENTRY_LONGER
]
tgt_text
=
[
"California's largest electricity provider has turned off power to hundreds of thousands of customers."
,
"N-Dubz have
sai
d they were surprised to get four nominations for this year's Mobo Awards."
,
"
Pop group
N-Dubz have
reveale
d they were surprised to get four nominations for this year's Mobo Awards."
,
]
@
cached_property
...
...
@@ -72,7 +72,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
torch_device
)
assert
inputs
.
input_ids
.
shape
==
(
2
,
421
)
translated_tokens
=
self
.
model
.
generate
(
**
inputs
)
translated_tokens
=
self
.
model
.
generate
(
**
inputs
,
num_beams
=
2
)
decoded
=
self
.
tokenizer
.
batch_decode
(
translated_tokens
,
skip_special_tokens
=
True
)
assert
self
.
tgt_text
==
decoded
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment