Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3d495c61
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6fe8a693ebbfa6e70b880f7c24e0cf524be6fb25"
Unverified
Commit
3d495c61
authored
Jun 16, 2020
by
Sam Shleifer
Committed by
GitHub
Jun 16, 2020
Browse files
Fix marian tokenizer save pretrained (#5043)
parent
d5477baf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
7 deletions
+14
-7
src/transformers/tokenization_marian.py
src/transformers/tokenization_marian.py
+7
-5
tests/test_tokenization_marian.py
tests/test_tokenization_marian.py
+7
-2
No files found.
src/transformers/tokenization_marian.py
View file @
3d495c61
...
@@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer):
...
@@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer):
def
__init__
(
def
__init__
(
self
,
self
,
vocab
=
None
,
vocab
,
source_spm
=
None
,
source_spm
,
target_spm
=
None
,
target_spm
,
source_lang
=
None
,
source_lang
=
None
,
target_lang
=
None
,
target_lang
=
None
,
unk_token
=
"<unk>"
,
unk_token
=
"<unk>"
,
...
@@ -59,6 +59,7 @@ class MarianTokenizer(PreTrainedTokenizer):
...
@@ -59,6 +59,7 @@ class MarianTokenizer(PreTrainedTokenizer):
pad_token
=
pad_token
,
pad_token
=
pad_token
,
**
kwargs
,
**
kwargs
,
)
)
assert
Path
(
source_spm
).
exists
(),
f
"cannot find spm source
{
source_spm
}
"
self
.
encoder
=
load_json
(
vocab
)
self
.
encoder
=
load_json
(
vocab
)
if
self
.
unk_token
not
in
self
.
encoder
:
if
self
.
unk_token
not
in
self
.
encoder
:
raise
KeyError
(
"<unk> token must be in vocab"
)
raise
KeyError
(
"<unk> token must be in vocab"
)
...
@@ -179,10 +180,11 @@ class MarianTokenizer(PreTrainedTokenizer):
...
@@ -179,10 +180,11 @@ class MarianTokenizer(PreTrainedTokenizer):
assert
save_dir
.
is_dir
(),
f
"
{
save_directory
}
should be a directory"
assert
save_dir
.
is_dir
(),
f
"
{
save_directory
}
should be a directory"
save_json
(
self
.
encoder
,
save_dir
/
self
.
vocab_files_names
[
"vocab"
])
save_json
(
self
.
encoder
,
save_dir
/
self
.
vocab_files_names
[
"vocab"
])
for
f
in
self
.
spm_files
:
for
orig
,
f
in
zip
([
"source.spm"
,
"target.spm"
],
self
.
spm_files
)
:
dest_path
=
save_dir
/
Path
(
f
).
name
dest_path
=
save_dir
/
Path
(
f
).
name
if
not
dest_path
.
exists
():
if
not
dest_path
.
exists
():
copyfile
(
f
,
save_dir
/
Path
(
f
).
name
)
copyfile
(
f
,
save_dir
/
orig
)
return
tuple
(
save_dir
/
f
for
f
in
self
.
vocab_files_names
)
return
tuple
(
save_dir
/
f
for
f
in
self
.
vocab_files_names
)
def
get_vocab
(
self
)
->
Dict
:
def
get_vocab
(
self
)
->
Dict
:
...
...
tests/test_tokenization_marian.py
View file @
3d495c61
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
os
import
os
import
tempfile
import
unittest
import
unittest
from
pathlib
import
Path
from
pathlib
import
Path
from
shutil
import
copyfile
from
shutil
import
copyfile
...
@@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
...
@@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
from
transformers.tokenization_utils
import
BatchEncoding
from
transformers.tokenization_utils
import
BatchEncoding
from
.test_tokenization_common
import
TokenizerTesterMixin
from
.test_tokenization_common
import
TokenizerTesterMixin
from
.utils
import
slow
SAMPLE_SP
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/test_sentencepiece.model"
)
SAMPLE_SP
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/test_sentencepiece.model"
)
...
@@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
"This is a test"
,
"This is a test"
,
)
)
@
slow
def
test_tokenizer_equivalence_en_de
(
self
):
def
test_tokenizer_equivalence_en_de
(
self
):
en_de_tokenizer
=
MarianTokenizer
.
from_pretrained
(
f
"
{
ORG_NAME
}
opus-mt-en-de"
)
en_de_tokenizer
=
MarianTokenizer
.
from_pretrained
(
f
"
{
ORG_NAME
}
opus-mt-en-de"
)
batch
=
en_de_tokenizer
.
prepare_translation_batch
([
"I am a small frog"
],
return_tensors
=
None
)
batch
=
en_de_tokenizer
.
prepare_translation_batch
([
"I am a small frog"
],
return_tensors
=
None
)
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
expected
=
[
38
,
121
,
14
,
697
,
38848
,
0
]
expected
=
[
38
,
121
,
14
,
697
,
38848
,
0
]
self
.
assertListEqual
(
expected
,
batch
.
input_ids
[
0
])
self
.
assertListEqual
(
expected
,
batch
.
input_ids
[
0
])
save_dir
=
tempfile
.
mkdtemp
()
en_de_tokenizer
.
save_pretrained
(
save_dir
)
contents
=
[
x
.
name
for
x
in
Path
(
save_dir
).
glob
(
"*"
)]
self
.
assertIn
(
"source.spm"
,
contents
)
MarianTokenizer
.
from_pretrained
(
save_dir
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment