Commit 08b59d10 (unverified), authored Jun 11, 2020 by Sam Shleifer, committed by GitHub on Jun 11, 2020

MBartTokenizer: add language codes (#3776)
Parent: 20451195

Showing 2 changed files with 161 additions and 38 deletions:

src/transformers/tokenization_bart.py  (+100, -0)
tests/test_modeling_bart.py            (+61, -38)
src/transformers/tokenization_bart.py
@@ -14,8 +14,10 @@
 # limitations under the License.
 import logging
+from typing import List, Optional

 from .tokenization_roberta import RobertaTokenizer
+from .tokenization_utils import BatchEncoding
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
@@ -47,6 +49,104 @@ SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-la
 class MBartTokenizer(XLMRobertaTokenizer):
+    """
+    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
+    Other tokenizer methods like encode do not work properly.
+    The tokenization method is <tokens> <eos> <language code>. There is no BOS token.
+
+    Examples::
+        from transformers import MBartTokenizer
+        tokenizer = MBartTokenizer.from_pretrained('mbart-large-en-ro')
+        example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
+        batch: dict = tokenizer.prepare_translation_batch(
+            example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
+        )
+    """
+
     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
     max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
     pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
+    lang_code_to_id = {  # NOTE(SS): resize embeddings will break this
+        "ar_AR": 250001,
+        "cs_CZ": 250002,
+        "de_DE": 250003,
+        "en_XX": 250004,
+        "es_XX": 250005,
+        "et_EE": 250006,
+        "fi_FI": 250007,
+        "fr_XX": 250008,
+        "gu_IN": 250009,
+        "hi_IN": 250010,
+        "it_IT": 250011,
+        "ja_XX": 250012,
+        "kk_KZ": 250013,
+        "ko_KR": 250014,
+        "lt_LT": 250015,
+        "lv_LV": 250016,
+        "my_MM": 250017,
+        "ne_NP": 250018,
+        "nl_XX": 250019,
+        "ro_RO": 250020,
+        "ru_RU": 250021,
+        "si_LK": 250022,
+        "tr_TR": 250023,
+        "vi_VN": 250024,
+        "zh_CN": 250025,
+    }
+    cur_lang_code = lang_code_to_id["en_XX"]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """Build model inputs from a sequence by appending eos_token_id."""
+        special_tokens = [self.eos_token_id, self.cur_lang_code]
+        if token_ids_1 is None:
+            return token_ids_0 + special_tokens
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + special_tokens
+
+    def prepare_translation_batch(
+        self,
+        src_texts: List[str],
+        src_lang: str = "en_XX",
+        tgt_texts: Optional[List[str]] = None,
+        tgt_lang: str = "ro_RO",
+        max_length: Optional[int] = None,
+        pad_to_max_length: bool = True,
+        return_tensors: str = "pt",
+    ) -> BatchEncoding:
+        """
+        Arguments:
+            src_texts: list of src language texts
+            src_lang: default en_XX (english)
+            tgt_texts: list of tgt language texts
+            tgt_lang: default ro_RO (romanian)
+            max_length: (None) defer to config (1024 for mbart-large-en-ro)
+            pad_to_max_length: (bool)
+
+        Returns:
+            dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
+        """
+        if max_length is None:
+            max_length = self.max_len
+        self.cur_lang_code = self.lang_code_to_id[src_lang]
+        model_inputs: BatchEncoding = self.batch_encode_plus(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            pad_to_max_length=pad_to_max_length,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        self.cur_lang_code = self.lang_code_to_id[tgt_lang]
+        decoder_inputs: BatchEncoding = self.batch_encode_plus(
+            tgt_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            pad_to_max_length=pad_to_max_length,
+        )
+        for k, v in decoder_inputs.items():
+            model_inputs[f"decoder_{k}"] = v
+        self.cur_lang_code = self.lang_code_to_id[src_lang]
+        return model_inputs
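Taken together, the change encodes every source sequence as <tokens> <eos> <src language code>, with the code id looked up in lang_code_to_id. A minimal usage sketch, not part of the commit itself, assuming transformers at this revision and the facebook/mbart-large-en-ro checkpoint:

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
batch = tokenizer.prepare_translation_batch(
    src_texts=[" UN Chief Says There Is No Military Solution in Syria"],
    src_lang="en_XX",
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    tgt_lang="ro_RO",
)

# Per build_inputs_with_special_tokens, the source ids end with </s> (id 2)
# followed by the en_XX code (id 250004 in lang_code_to_id).
src_ids = batch["input_ids"][0].tolist()
assert src_ids[-2:] == [2, tokenizer.lang_code_to_id["en_XX"]]

# Target texts come back under decoder_* keys, suffixed with the tgt_lang
# code (ro_RO -> 250020) in the same way.
assert "decoder_input_ids" in batch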
tests/test_modeling_bart.py
@@ -19,6 +19,7 @@ import unittest
 import timeout_decorator  # noqa

 from transformers import is_torch_available
+from transformers.file_utils import cached_property

 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
@@ -37,6 +38,7 @@ if is_torch_available():
         BartConfig,
         BartTokenizer,
         MBartTokenizer,
+        BatchEncoding,
     )
     from transformers.modeling_bart import (
         BART_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
             tiny(**inputs_dict)


+EN_CODE = 250004
+
+
 @require_torch
-class BartTranslationTests(unittest.TestCase):
-    _model = None
+class MBartIntegrationTests(unittest.TestCase):
+    src_text = [
+        " UN Chief Says There Is No Military Solution in Syria",
+        " I ate lunch twice yesterday",
+    ]
+    tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
+    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

     @classmethod
     def setUpClass(cls):
         checkpoint_name = "facebook/mbart-large-en-ro"
         cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
         cls.pad_token_id = 1
+        return cls
+
+    @cached_property
+    def model(self):
+        """Only load the model if needed."""
+        model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
+        if "cuda" in torch_device:
+            model = model.half()
+        return model
+
+    @slow
+    def test_enro_forward(self):
+        model = self.model
         net_input = {
             "input_ids": _long_tensor(
                 [[
@@ -221,24 +245,9 @@ class BartTranslationTests(unittest.TestCase):
             ),
             "generation_mode": False,
         }
-        net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
-        cls.net_input = net_input
-        return cls
-
-    @property
-    def model(self):
-        """Only load the model if needed."""
-        if self._model is None:
-            model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
-            self._model = model.to(torch_device)
-        return self._model
-
-    @slow
-    def test_enro_forward(self):
-        model = self.model
+        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
         with torch.no_grad():
-            logits, *other_stuff = model(**self.net_input)
+            logits, *other_stuff = model(**net_input)

         expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
         result_slice = logits[0][0][:3]
@@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase):
     @slow
     def test_enro_generate(self):
-        model = self.model
-        # example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
-        # inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
-        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-        inputs = {
-            "input_ids": torch.LongTensor(
-                [[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]]  # 250004
-            )
-        }
-        translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
+        inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
+        translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
-        self.assertEqual(expected_translation_romanian, decoded[0])
+        self.assertEqual(self.tgt_text[0], decoded[0])

     def test_mbart_enro_config(self):
         mbart_models = ["facebook/mbart-large-en-ro"]
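The rewritten generate test now exercises prepare_translation_batch end to end instead of a hand-built input dict. A rough standalone equivalent of that test, a sketch assuming the slow facebook/mbart-large-en-ro download is acceptable:

import torch
from transformers import BartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

batch = tokenizer.prepare_translation_batch([" UN Chief Says There Is No Military Solution in Syria"])
with torch.no_grad():
    translated_tokens = model.generate(input_ids=batch["input_ids"])
decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
# Per the test's tgt_text fixture, this should decode to:
# "Şeful ONU declară că nu există o soluţie militară în Siria"
print(decoded[0])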
@@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
                         e.args += (name, k)
                         raise

-    def test_enro_tokenizer(self):
-        raw = "UN Chief Says There Is No Military Solution in Syria"
-        ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
-        expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
-        # TODO(SS): should be [8274, ..., 2, 250020]
-        self.assertListEqual(expected_result, ids)
-
     def test_mbart_fast_forward(self):
         config = BartConfig(
             vocab_size=99,
@@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
         self.assertEqual(logits.shape, expected_shape)


+@require_torch
+class MBartTokenizerTests(MBartIntegrationTests):
+    def test_enro_tokenizer_prepare_translation_batch(self):
+        batch = self.tokenizer.prepare_translation_batch(
+            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+        self.assertEqual((2, 14), batch.input_ids.shape)
+        self.assertEqual((2, 14), batch.attention_mask.shape)
+        result = batch.input_ids.tolist()[0]
+        self.assertListEqual(self.expected_src_tokens, result)
+        self.assertEqual(2, batch.decoder_input_ids[0, -2])  # EOS
+
+    def test_enro_tokenizer_batch_encode_plus(self):
+        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
+        self.assertListEqual(self.expected_src_tokens, ids)
+
+    def test_enro_tokenizer_truncation(self):
+        src_text = ["this is gunna be a long sentence " * 20]
+        assert isinstance(src_text[0], str)
+        desired_max_length = 10
+        ids = self.tokenizer.prepare_translation_batch(
+            src_text, return_tensors=None, max_length=desired_max_length
+        ).input_ids[0]
+        self.assertEqual(ids[-2], 2)
+        self.assertEqual(ids[-1], EN_CODE)
+        self.assertEqual(len(ids), desired_max_length)
+
+
 @require_torch
 class BartHeadTests(unittest.TestCase):
     vocab_size = 99
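The truncation test pins down a useful property: even when max_length cuts the input short, the two-token suffix survives at the end of the sequence. A quick sanity-check sketch outside unittest, assuming the same en-ro checkpoint:

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
ids = tokenizer.prepare_translation_batch(
    ["this is gunna be a long sentence " * 20],
    return_tensors=None,  # plain Python lists instead of torch tensors
    max_length=10,        # force truncation well below the input length
).input_ids[0]

assert len(ids) == 10
assert ids[-2:] == [2, 250004]  # </s> then the en_XX language code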