Unverified commit efbc1c5a, authored May 19, 2020 by Sam Shleifer, committed via GitHub on May 19, 2020

[MarianTokenizer] implement save_vocabulary and other common methods (#4389)

Parent: 956c4c4e
Showing 3 changed files with 145 additions and 15 deletions:

src/transformers/tokenization_marian.py (+75, -10)
tests/test_modeling_marian.py (+0, -5)
tests/test_tokenization_marian.py (+70, -0)
src/transformers/tokenization_marian.py
 import json
 import re
 import warnings
-from typing import Dict, List, Optional, Union
+from pathlib import Path
+from shutil import copyfile
+from typing import Dict, List, Optional, Tuple, Union

 import sentencepiece
...
@@ -15,7 +17,7 @@ vocab_files_names = {
     "vocab": "vocab.json",
     "tokenizer_config_file": "tokenizer_config.json",
 }
-MODEL_NAMES = ("opus-mt-en-de",)  # TODO(SS): the only required constant is vocab_files_names
+MODEL_NAMES = ("opus-mt-en-de",)  # TODO(SS): delete this, the only required constant is vocab_files_names
 PRETRAINED_VOCAB_FILES_MAP = {
     k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES} for k, fname in vocab_files_names.items()
...
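For readability, here is a sketch of what that nested comprehension expands to. The S3_BUCKET_PREFIX constant and the source/target spm entries of vocab_files_names sit outside this hunk, so the stand-in values below are assumptions consistent with the rest of the diff:

# Stand-in values; the real S3_BUCKET_PREFIX and the spm entries are defined
# elsewhere in the library, outside the lines shown in this hunk.
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
vocab_files_names = {
    "source_spm": "source.spm",
    "target_spm": "target.spm",
    "vocab": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}
MODEL_NAMES = ("opus-mt-en-de",)

# One URL per (file kind, model name) pair.
PRETRAINED_VOCAB_FILES_MAP = {
    k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
    for k, fname in vocab_files_names.items()
}
# e.g. PRETRAINED_VOCAB_FILES_MAP["vocab"]["opus-mt-en-de"]
#   -> ".../Helsinki-NLP/opus-mt-en-de/vocab.json"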
@@ -55,14 +57,16 @@ class MarianTokenizer(PreTrainedTokenizer):
         eos_token="</s>",
         pad_token="<pad>",
         max_len=512,
         **kwargs,
     ):
         super().__init__(
-            # bos_token=bos_token,
+            # bos_token=bos_token,  unused. Start decoding with config.decoder_start_token_id
             max_len=max_len,
             eos_token=eos_token,
             unk_token=unk_token,
             pad_token=pad_token,
             **kwargs,
         )
         self.encoder = load_json(vocab)
         if self.unk_token not in self.encoder:
...
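The re-worded comment documents why no bos_token is forwarded to super().__init__: Marian models begin decoding from config.decoder_start_token_id rather than from a BOS token. A quick illustration of the resulting special-token setup (a sketch; assumes network access to download the checkpoint):

from transformers import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
print(tok.eos_token, tok.unk_token, tok.pad_token)  # </s> <unk> <pad>
print(tok.bos_token)  # None: Marian defines no BOS token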
@@ -72,21 +76,23 @@ class MarianTokenizer(PreTrainedTokenizer):
         self.source_lang = source_lang
         self.target_lang = target_lang
-        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
+        self.spm_files = [source_spm, target_spm]

         # load SentencePiece model for pre-processing
-        self.spm_source = sentencepiece.SentencePieceProcessor()
-        self.spm_source.Load(source_spm)
-        self.spm_target = sentencepiece.SentencePieceProcessor()
-        self.spm_target.Load(target_spm)
+        self.spm_source = load_spm(source_spm)
+        self.spm_target = load_spm(target_spm)
         self.current_spm = self.spm_source

         # Multilingual target side: default to using first supported language code.
+        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
         self._setup_normalizer()

     def _setup_normalizer(self):
         try:
             from mosestokenizer import MosesPunctuationNormalizer

-            self.punc_normalizer = MosesPunctuationNormalizer(source_lang)
+            self.punc_normalizer = MosesPunctuationNormalizer(self.source_lang)
         except ImportError:
             warnings.warn("Recommended: pip install mosestokenizer")
             self.punc_normalizer = lambda x: x
...
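The try/except around mosestokenizer (which also fixes a bug: the old code referenced the local name source_lang instead of self.source_lang) gives the tokenizer a graceful fallback: when the optional dependency is missing, punctuation normalization becomes a no-op. A minimal standalone sketch of the same pattern (the make_punc_normalizer name is illustrative, not from the diff):

import warnings

def make_punc_normalizer(lang: str):
    """Return a Moses punctuation normalizer, or an identity fallback
    when the optional mosestokenizer dependency is not installed."""
    try:
        from mosestokenizer import MosesPunctuationNormalizer  # optional dependency
        return MosesPunctuationNormalizer(lang)
    except ImportError:
        warnings.warn("Recommended: pip install mosestokenizer")
        return lambda x: x  # no-op: text passes through unchanged

normalize = make_punc_normalizer("en")
# Normalized if mosestokenizer is available, returned unchanged otherwise.
print(normalize("Hello , world !"))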
@@ -176,6 +182,65 @@ class MarianTokenizer(PreTrainedTokenizer):
     def vocab_size(self) -> int:
         return len(self.encoder)

+    def save_vocabulary(self, save_directory: str) -> Tuple[str]:
+        """save vocab file to json and copy spm files from their original path."""
+        save_dir = Path(save_directory)
+        assert save_dir.is_dir(), f"{save_directory} should be a directory"
+        save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])
+
+        for f in self.spm_files:
+            dest_path = save_dir / Path(f).name
+            if not dest_path.exists():
+                copyfile(f, save_dir / Path(f).name)
+
+        return tuple(save_dir / f for f in self.vocab_files_names)
+
+    def get_vocab(self) -> Dict:
+        vocab = self.encoder.copy()
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def __getstate__(self) -> Dict:
+        state = self.__dict__.copy()
+        state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
+        return state
+
+    def __setstate__(self, d: Dict) -> None:
+        self.__dict__ = d
+        self.spm_source, self.spm_target = (load_spm(f) for f in self.spm_files)
+        self.current_spm = self.spm_source
+        self._setup_normalizer()
+
+    def num_special_tokens_to_add(self, **unused):
+        """Just EOS"""
+        return 1
+
+    def _special_token_mask(self, seq):
+        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
+        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
+        return [1 if x in all_special_ids else 0 for x in seq]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+        if already_has_special_tokens:
+            return self._special_token_mask(token_ids_0)
+        elif token_ids_1 is None:
+            return self._special_token_mask(token_ids_0) + [1]
+        else:
+            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+
+def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
+    spm = sentencepiece.SentencePieceProcessor()
+    spm.Load(path)
+    return spm
+
+
+def save_json(data, path: str) -> None:
+    with open(path, "w") as f:
+        json.dump(data, f, indent=2)
+
+
 def load_json(path: str) -> Union[Dict, List]:
     with open(path, "r") as f:
...
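Taken together, these additions make the tokenizer saveable and picklable. A minimal usage sketch (assumes network access to download the Helsinki-NLP/opus-mt-en-de checkpoint; the directory name is illustrative):

import os
import pickle

from transformers import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

# save_pretrained delegates to the new save_vocabulary: the vocab is
# re-serialized to JSON and the two SentencePiece files are copied alongside it.
os.makedirs("saved_marian", exist_ok=True)  # save_vocabulary asserts the directory exists
tok.save_pretrained("saved_marian")
reloaded = MarianTokenizer.from_pretrained("saved_marian")

# __getstate__ nulls the unpicklable SentencePieceProcessor objects and the
# normalizer; __setstate__ rebuilds them from self.spm_files on unpickling.
tok2 = pickle.loads(pickle.dumps(tok))
assert tok2.get_vocab() == tok.get_vocab()

# get_special_tokens_mask flags eos/pad with 1; <unk> counts as a normal token.
ids = tok.prepare_translation_batch(["I am a small frog"], return_tensors=None).input_ids[0]
print(tok.get_special_tokens_mask(ids, already_has_special_tokens=True))  # [0, 0, 0, 0, 0, 1]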
tests/test_modeling_marian.py
...
@@ -129,11 +129,6 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
         max_indices = logits.argmax(-1)
         self.tokenizer.batch_decode(max_indices)

-    def test_tokenizer_equivalence(self):
-        batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
-        expected = [38, 121, 14, 697, 38848, 0]
-        self.assertListEqual(expected, batch.input_ids[0].tolist())
-
     def test_unk_support(self):
         t = self.tokenizer
         ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
...
tests/test_tokenization_marian.py
new file (mode 100644)
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from pathlib import Path
from shutil import copyfile

from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_files_names
from transformers.tokenization_utils import BatchEncoding

from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow


SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
ORG_NAME = "Helsinki-NLP/"


class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = MarianTokenizer

    def setUp(self):
        super().setUp()
        vocab = ["</s>", "<unk>", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", "<pad>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        save_dir = Path(self.tmpdirname)
        save_json(vocab_tokens, save_dir / vocab_files_names["vocab"])
        save_json(mock_tokenizer_config, save_dir / vocab_files_names["tokenizer_config_file"])
        if not (save_dir / vocab_files_names["source_spm"]).exists():
            copyfile(SAMPLE_SP, save_dir / vocab_files_names["source_spm"])
            copyfile(SAMPLE_SP, save_dir / vocab_files_names["target_spm"])

        tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
        # overwrite max_len=512 default
        return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs)

    def get_input_output_texts(self):
        return (
            "This is a test",
            "This is a test",
        )

    @slow
    def test_tokenizer_equivalence_en_de(self):
        en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
        batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
        self.assertIsInstance(batch, BatchEncoding)
        expected = [38, 121, 14, 697, 38848, 0]
        self.assertListEqual(expected, batch.input_ids[0])
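The integration test removed from tests/test_modeling_marian.py above reappears here as the @slow test_tokenizer_equivalence_en_de, asserting the same expected ids without requiring torch. Assuming the repository root as working directory, a hypothetical invocation of the new module:

import pytest

# In this codebase, @slow tests are skipped unless RUN_SLOW=1 is set in the environment.
pytest.main(["tests/test_tokenization_marian.py", "-q"])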