Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4ab74245
"src/vscode:/vscode.git/clone" did not exist on "035d58bc3aebc9a8d033de2c995263174c7741da"
Unverified
Commit
4ab74245
authored
Jun 05, 2020
by
Sam Shleifer
Committed by
GitHub
Jun 05, 2020
Browse files
[cleanup/marian] pipelines test and new kwarg (#4812)
parent
875288b3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
6 deletions
+12
-6
src/transformers/tokenization_marian.py
src/transformers/tokenization_marian.py
+3
-4
tests/test_modeling_marian.py
tests/test_modeling_marian.py
+8
-0
tests/test_tokenization_marian.py
tests/test_tokenization_marian.py
+1
-2
No files found.
src/transformers/tokenization_marian.py
View file @
4ab74245
...
...
@@ -48,13 +48,12 @@ class MarianTokenizer(PreTrainedTokenizer):
unk_token
=
"<unk>"
,
eos_token
=
"</s>"
,
pad_token
=
"<pad>"
,
max_len
=
512
,
**
kwargs
,
model_
max_len
gth
=
512
,
**
kwargs
):
super
().
__init__
(
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
m
ax_len
=
max_len
,
m
odel_max_length
=
model_
max_len
gth
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
pad_token
=
pad_token
,
...
...
tests/test_modeling_marian.py
View file @
4ab74245
...
...
@@ -38,6 +38,7 @@ if is_torch_available():
convert_opus_name_to_hf_name
,
ORG_NAME
,
)
from
transformers.pipelines
import
TranslationPipeline
class
ModelManagementTests
(
unittest
.
TestCase
):
...
...
@@ -189,6 +190,7 @@ class TestMarian_RU_FR(MarianIntegrationTest):
src_text
=
[
"Он показал мне рукопись своей новой пьесы."
]
expected_text
=
[
"Il m'a montré le manuscrit de sa nouvelle pièce."
]
@
slow
def
test_batch_generation_ru_fr
(
self
):
self
.
_assert_generated_batch_equal_expected
()
...
...
@@ -199,6 +201,7 @@ class TestMarian_MT_EN(MarianIntegrationTest):
src_text
=
[
"Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."
]
expected_text
=
[
"Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."
]
@
slow
def
test_batch_generation_mt_en
(
self
):
self
.
_assert_generated_batch_equal_expected
()
...
...
@@ -229,6 +232,11 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
with
self
.
assertRaises
(
ValueError
):
self
.
tokenizer
.
prepare_translation_batch
([
""
])
def
test_pipeline
(
self
):
pipeline
=
TranslationPipeline
(
self
.
model
,
self
.
tokenizer
,
framework
=
"pt"
)
output
=
pipeline
(
self
.
src_text
)
self
.
assertEqual
(
self
.
expected_text
,
[
x
[
"translation_text"
]
for
x
in
output
])
@
require_torch
class
TestConversionUtils
(
unittest
.
TestCase
):
...
...
tests/test_tokenization_marian.py
View file @
4ab74245
...
...
@@ -52,8 +52,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer
.
save_pretrained
(
self
.
tmpdirname
)
def
get_tokenizer
(
self
,
max_len
=
None
,
**
kwargs
)
->
MarianTokenizer
:
# overwrite max_len=512 default
return
MarianTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
max_len
=
max_len
,
**
kwargs
)
return
MarianTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
model_max_length
=
max_len
,
**
kwargs
)
def
get_input_output_texts
(
self
):
return
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment