Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
43cb03a9
Unverified
Commit
43cb03a9
authored
Jul 01, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 01, 2020
Browse files
MarianTokenizer.prepare_translation_batch uses new tokenizer API (#5182)
parent
13deb95a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
9 deletions
+27
-9
src/transformers/tokenization_marian.py
src/transformers/tokenization_marian.py
+8
-9
tests/test_tokenization_marian.py
tests/test_tokenization_marian.py
+19
-0
No files found.
src/transformers/tokenization_marian.py
View file @
43cb03a9
...
...
@@ -129,6 +129,8 @@ class MarianTokenizer(PreTrainedTokenizer):
max_length
:
Optional
[
int
]
=
None
,
pad_to_max_length
:
bool
=
True
,
return_tensors
:
str
=
"pt"
,
truncation_strategy
=
"only_first"
,
padding
=
"longest"
,
)
->
BatchEncoding
:
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
Arguments:
...
...
@@ -147,24 +149,21 @@ class MarianTokenizer(PreTrainedTokenizer):
raise
ValueError
(
f
"found empty string in src_texts:
{
src_texts
}
"
)
self
.
current_spm
=
self
.
spm_source
src_texts
=
[
self
.
normalize
(
t
)
for
t
in
src_texts
]
# this does not appear to do much
model_inputs
:
BatchEncoding
=
self
.
batch_encode_plus
(
src_texts
,
tokenizer_kwargs
=
dict
(
add_special_tokens
=
True
,
return_tensors
=
return_tensors
,
max_length
=
max_length
,
pad_to_max_length
=
pad_to_max_length
,
truncation_strategy
=
truncation_strategy
,
padding
=
padding
,
)
model_inputs
:
BatchEncoding
=
self
(
src_texts
,
**
tokenizer_kwargs
)
if
tgt_texts
is
None
:
return
model_inputs
self
.
current_spm
=
self
.
spm_target
decoder_inputs
:
BatchEncoding
=
self
.
batch_encode_plus
(
tgt_texts
,
add_special_tokens
=
True
,
return_tensors
=
return_tensors
,
max_length
=
max_length
,
pad_to_max_length
=
pad_to_max_length
,
)
decoder_inputs
:
BatchEncoding
=
self
(
tgt_texts
,
**
tokenizer_kwargs
)
for
k
,
v
in
decoder_inputs
.
items
():
model_inputs
[
f
"decoder_
{
k
}
"
]
=
v
self
.
current_spm
=
self
.
spm_source
...
...
tests/test_tokenization_marian.py
View file @
43cb03a9
...
...
@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
from
transformers.tokenization_utils
import
BatchEncoding
from
.test_tokenization_common
import
TokenizerTesterMixin
from
.utils
import
_torch_available
SAMPLE_SP
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/test_sentencepiece.model"
)
...
...
@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
mock_tokenizer_config
=
{
"target_lang"
:
"fi"
,
"source_lang"
:
"en"
}
zh_code
=
">>zh<<"
ORG_NAME
=
"Helsinki-NLP/"
FRAMEWORK
=
"pt"
if
_torch_available
else
"tf"
class
MarianTokenizationTest
(
TokenizerTesterMixin
,
unittest
.
TestCase
):
...
...
@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
contents
=
[
x
.
name
for
x
in
Path
(
save_dir
).
glob
(
"*"
)]
self
.
assertIn
(
"source.spm"
,
contents
)
MarianTokenizer
.
from_pretrained
(
save_dir
)
def
test_outputs_not_longer_than_maxlen
(
self
):
tok
=
self
.
get_tokenizer
()
batch
=
tok
.
prepare_translation_batch
(
[
"I am a small frog"
*
1000
,
"I am a small frog"
],
return_tensors
=
FRAMEWORK
)
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
self
.
assertEqual
(
batch
.
input_ids
.
shape
,
(
2
,
512
))
def
test_outputs_can_be_shorter
(
self
):
tok
=
self
.
get_tokenizer
()
batch_smaller
=
tok
.
prepare_translation_batch
(
[
"I am a tiny frog"
,
"I am a small frog"
],
return_tensors
=
FRAMEWORK
)
self
.
assertIsInstance
(
batch_smaller
,
BatchEncoding
)
self
.
assertEqual
(
batch_smaller
.
input_ids
.
shape
,
(
2
,
10
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment