chenpangpang / transformers / Commits

Unverified commit 43cb03a9, authored Jul 01, 2020 by Sam Shleifer, committed via GitHub on Jul 01, 2020

MarianTokenizer.prepare_translation_batch uses new tokenizer API (#5182)
Parent: 13deb95a

Showing 2 changed files with 27 additions and 9 deletions (+27 -9):

    src/transformers/tokenization_marian.py   +8  -9
    tests/test_tokenization_marian.py         +19 -0
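For readers unfamiliar with the "new tokenizer API" named in the commit title: from transformers 3.x onward, the preferred way to encode a batch is to call the tokenizer object directly (its __call__ method) rather than batch_encode_plus. The sketch below contrasts the two styles; it is illustrative only and not part of this commit, and the checkpoint name and keyword values are examples.

    # Illustrative sketch; the checkpoint name and kwargs below are examples, not from the diff.
    from transformers import MarianTokenizer

    tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    texts = ["I am a small frog", "Tuesday is a big day"]

    # Older style, as used before this commit:
    old_batch = tok.batch_encode_plus(texts, return_tensors="pt", pad_to_max_length=True)

    # Newer style adopted by this commit: call the tokenizer instance directly.
    new_batch = tok(texts, return_tensors="pt", padding="longest")

    # Both return a BatchEncoding with input_ids / attention_mask tensors.
    print(old_batch["input_ids"].shape, new_batch["input_ids"].shape)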
src/transformers/tokenization_marian.py
@@ -129,6 +129,8 @@ class MarianTokenizer(PreTrainedTokenizer):
         max_length: Optional[int] = None,
         pad_to_max_length: bool = True,
         return_tensors: str = "pt",
+        truncation_strategy="only_first",
+        padding="longest",
     ) -> BatchEncoding:
         """Prepare model inputs for translation. For best performance, translate one sentence at a time.
         Arguments:
@@ -147,24 +149,21 @@ class MarianTokenizer(PreTrainedTokenizer):
             raise ValueError(f"found empty string in src_texts: {src_texts}")
         self.current_spm = self.spm_source
         src_texts = [self.normalize(t) for t in src_texts]  # this does not appear to do much
-        model_inputs: BatchEncoding = self.batch_encode_plus(
-            src_texts,
+        tokenizer_kwargs = dict(
             add_special_tokens=True,
             return_tensors=return_tensors,
             max_length=max_length,
             pad_to_max_length=pad_to_max_length,
+            truncation_strategy=truncation_strategy,
+            padding=padding,
         )
+        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
         if tgt_texts is None:
             return model_inputs
         self.current_spm = self.spm_target
-        decoder_inputs: BatchEncoding = self.batch_encode_plus(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            pad_to_max_length=pad_to_max_length,
-        )
+        decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v
         self.current_spm = self.spm_source
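To see the refactored method end to end, prepare_translation_batch can be exercised roughly as below. This is a sketch for context, not part of the commit; the checkpoint name is an example and the exact set of returned keys can vary with the library version.

    # Sketch, not part of the diff; "Helsinki-NLP/opus-mt-en-de" is an example checkpoint.
    from transformers import MarianTokenizer

    tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

    batch = tok.prepare_translation_batch(
        src_texts=["I am a small frog"],
        tgt_texts=["Ich bin ein kleiner Frosch"],
        return_tensors="pt",
    )

    # Source-side tensors use the usual keys; target-side tensors are copied in
    # under "decoder_"-prefixed keys by the loop shown above.
    print(sorted(batch.keys()))
    # e.g. ['attention_mask', 'decoder_attention_mask', 'decoder_input_ids', 'input_ids']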
tests/test_tokenization_marian.py
@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
 from transformers.tokenization_utils import BatchEncoding
 from .test_tokenization_common import TokenizerTesterMixin
+from .utils import _torch_available

 SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
+FRAMEWORK = "pt" if _torch_available else "tf"

 class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         contents = [x.name for x in Path(save_dir).glob("*")]
         self.assertIn("source.spm", contents)
         MarianTokenizer.from_pretrained(save_dir)
+
+    def test_outputs_not_longer_than_maxlen(self):
+        tok = self.get_tokenizer()
+
+        batch = tok.prepare_translation_batch(
+            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+        self.assertEqual(batch.input_ids.shape, (2, 512))
+
+    def test_outputs_can_be_shorter(self):
+        tok = self.get_tokenizer()
+        batch_smaller = tok.prepare_translation_batch(
+            ["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
+        )
+        self.assertIsInstance(batch_smaller, BatchEncoding)
+        self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
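The two new tests cover both ends of the padding behaviour: an input repeated 1000 times is capped at the Marian model max length, giving shape (2, 512), while two short sentences are only padded to the longest one in the batch, giving shape (2, 10). One way to run just these tests locally is sketched below; it assumes you are at the repository root with the test dependencies installed and the tests directory importable as a package, and the usual pytest tests/test_tokenization_marian.py works equally well.

    # Sketch for running only the two new tests (assumptions noted in the text above).
    import unittest

    names = [
        "tests.test_tokenization_marian.MarianTokenizationTest.test_outputs_not_longer_than_maxlen",
        "tests.test_tokenization_marian.MarianTokenizationTest.test_outputs_can_be_shorter",
    ]
    suite = unittest.TestLoader().loadTestsFromNames(names)
    unittest.TextTestRunner(verbosity=2).run(suite)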