chenpangpang / transformers · Commits

Commit 5abe5038 (unverified)
Fix #6096: MBartTokenizer's mask token (#6098)
Authored Jul 28, 2020 by Sam Shleifer; committed via GitHub on Jul 28, 2020.
Parent: b1c8b769
Showing 3 changed files with 24 additions and 0 deletions (+24 −0).
src/transformers/tokenization_bart.py   (+1 −0)
tests/test_modeling_mbart.py            (+12 −0)
tests/test_tokenization_mbart.py        (+11 −0)
src/transformers/tokenization_bart.py
@@ -122,6 +122,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         }
         self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
         self.cur_lang_code = self.lang_code_to_id["en_XX"]
+        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
 
         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
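The one-line fix pins <mask> to the slot immediately after the language-code block, and it is registered before fairseq_ids_to_tokens is rebuilt, so the reverse lookup picks it up too. A minimal sketch of the resulting ID arithmetic; the sizes below are assumptions inferred from the IDs asserted in the new tokenizer test, not attributes read off the tokenizer:

    # Sketch of the MBart vocabulary layout implied by the fix.
    # Assumed sizes (from the test's expected IDs): 250,000 SentencePiece
    # pieces, a fairseq offset of 1, and 25 language codes.
    sp_model_size = 250_000  # SentencePiece pieces in mbart-large-cc25
    fairseq_offset = 1       # fairseq shifts SentencePiece IDs by one
    num_lang_codes = 25      # FAIRSEQ_LANGUAGE_CODES; ar_AR is the first entry

    # Language codes occupy the block right after the SentencePiece vocab...
    ar_AR_id = sp_model_size + fairseq_offset + 0
    # ...and the fix places <mask> one slot past the end of that block.
    mask_id = sp_model_size + num_lang_codes + fairseq_offset

    assert ar_AR_id == 250_001 and mask_id == 250_026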
tests/test_modeling_mbart.py
@@ -123,6 +123,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
         self.assertEqual(logits.shape, expected_shape)
 
 
+@require_torch
 class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
     checkpoint_name = "facebook/mbart-large-cc25"
     src_text = [

@@ -140,3 +141,14 @@ class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
         )
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
         self.assertEqual(self.tgt_text[0], decoded[0])
+
+    @slow
+    def test_fill_mask(self):
+        inputs = self.tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"]).to(torch_device)
+        outputs = self.model.generate(
+            inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
+        )
+        prediction: str = self.tokenizer.batch_decode(
+            outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
+        )[0]
+        self.assertEqual(prediction, "of the best books I ever read!")
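Outside the test harness, test_fill_mask boils down to ordinary generation over a masked input. A rough standalone sketch against the API of this era; it assumes the 2020-vintage setup where a BART conditional-generation class loads mbart checkpoints and the tokenizer still exposes prepare_translation_batch (later renamed), so treat it as a period-accurate approximation rather than current usage:

    from transformers import BartForConditionalGeneration, MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
    model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

    # Encode a sentence containing the now-correctly-registered <mask> token.
    batch = tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"])

    # Greedy decoding (num_beams=1), starting from the en_XX language code.
    generated = model.generate(
        batch["input_ids"],
        decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"],
        num_beams=1,
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
    # The test expects: "of the best books I ever read!"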
tests/test_tokenization_mbart.py
+import tempfile
 import unittest
 
 from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer

@@ -171,3 +172,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(ids[-2], 2)
         self.assertEqual(ids[-1], EN_CODE)
         self.assertEqual(len(ids), desired_max_length)
+
+    def test_mask_token(self):
+        self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250026, 250001])
+
+    def test_special_tokens_unaffacted_by_save_load(self):
+        tmpdirname = tempfile.mkdtemp()
+        original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
+        self.tokenizer.save_pretrained(tmpdirname)
+        new_tok = MBartTokenizer.from_pretrained(tmpdirname)
+        self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
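The two new tokenizer tests pin down the fix from both sides: <mask> now resolves to 250026, clear of the language-code block that opens at ar_AR = 250001 (the inherited XLM-R mapping had mis-placed it), and the corrected table survives serialization. The same checks as a standalone script, assuming the facebook/mbart-large-cc25 files are available:

    import tempfile

    from transformers import MBartTokenizer

    tok = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")

    # <mask> lands one slot past the last language code; ar_AR opens the block.
    assert tok.convert_tokens_to_ids(["<mask>", "ar_AR"]) == [250026, 250001]

    # The special-token table should survive a save/load round trip.
    tmpdir = tempfile.mkdtemp()
    tok.save_pretrained(tmpdir)
    reloaded = MBartTokenizer.from_pretrained(tmpdir)
    assert reloaded.fairseq_tokens_to_ids == tok.fairseq_tokens_to_ids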