Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f1c71da1
Commit
f1c71da1
authored
Mar 12, 2020
by
Patrick von Platen
Browse files
fix eos_token_ids in test
parent
6047f46b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+3
-3
No files found.
tests/test_modeling_bart.py
View file @
f1c71da1
...
@@ -61,7 +61,7 @@ class ModelTester:
...
@@ -61,7 +61,7 @@ class ModelTester:
self
.
hidden_dropout_prob
=
0.1
self
.
hidden_dropout_prob
=
0.1
self
.
attention_probs_dropout_prob
=
0.1
self
.
attention_probs_dropout_prob
=
0.1
self
.
max_position_embeddings
=
20
self
.
max_position_embeddings
=
20
self
.
eos_token_id
=
2
self
.
eos_token_id
s
=
[
2
]
self
.
pad_token_id
=
1
self
.
pad_token_id
=
1
self
.
bos_token_id
=
0
self
.
bos_token_id
=
0
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
@@ -436,7 +436,7 @@ class BartModelIntegrationTest(unittest.TestCase):
...
@@ -436,7 +436,7 @@ class BartModelIntegrationTest(unittest.TestCase):
num_beams
=
4
,
num_beams
=
4
,
max_length
=
extra_len
+
2
,
max_length
=
extra_len
+
2
,
do_sample
=
False
,
do_sample
=
False
,
decoder_start_token_id
=
hf
.
config
.
eos_token_id
,
decoder_start_token_id
=
hf
.
config
.
eos_token_id
s
[
0
]
,
)
# repetition_penalty=10.,
)
# repetition_penalty=10.,
expected_result
=
"<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
expected_result
=
"<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated
=
[
tok
.
decode
(
g
,)
for
g
in
gen_tokens
]
generated
=
[
tok
.
decode
(
g
,)
for
g
in
gen_tokens
]
...
@@ -481,7 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
...
@@ -481,7 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
no_repeat_ngram_size
=
3
,
no_repeat_ngram_size
=
3
,
do_sample
=
False
,
do_sample
=
False
,
early_stopping
=
True
,
early_stopping
=
True
,
decoder_start_token_id
=
hf
.
config
.
eos_token_id
,
decoder_start_token_id
=
hf
.
config
.
eos_token_id
s
[
0
]
,
)
)
decoded
=
[
decoded
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment