chenpangpang / transformers · Commit d6eab530 (Unverified)

Authored Jul 07, 2020 by Sam Shleifer; committed by GitHub Jul 07, 2020

mbart.prepare_translation_batch: pass through kwargs (#5581)

parent 353b8f1e
Showing 1 changed file with 4 additions and 0 deletions.

src/transformers/tokenization_bart.py  (+4, -0)
@@ -198,6 +198,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         max_length: Optional[int] = None,
         padding: str = "longest",
         return_tensors: str = "pt",
+        **kwargs,
     ) -> BatchEncoding:
         """Prepare a batch that can be passed directly to an instance of MBartModel.
         Arguments:
@@ -207,6 +208,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             tgt_lang: default ro_RO (romanian), the language we are translating to
             max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*
             padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
+            **kwargs: passed to self.__call__
         Returns:
             :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
@@ -221,6 +223,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             max_length=max_length,
             padding=padding,
             truncation=True,
+            **kwargs,
         )
         if tgt_texts is None:
             return model_inputs
@@ -232,6 +235,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             padding=padding,
             max_length=max_length,
             truncation=True,
+            **kwargs,
         )
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v
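For illustration (not part of the commit): a minimal usage sketch of what this change enables, assuming the transformers API at the time of this commit (~v3.0), when MBartTokenizer.prepare_translation_batch was the public entry point. return_length is one example of a tokenizer __call__ argument that can now be forwarded; before this change, any extra keyword argument raised a TypeError.

    # Hedged sketch, not from the commit: assumes transformers ~v3.0 and the
    # facebook/mbart-large-en-ro checkpoint used in the docstring examples.
    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

    batch = tokenizer.prepare_translation_batch(
        src_texts=["UN Chief Says There Is No Military Solution in Syria"],
        tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
        src_lang="en_XX",
        tgt_lang="ro_RO",
        # Previously this line raised:
        #   TypeError: prepare_translation_batch() got an unexpected keyword argument
        # Now it is forwarded via **kwargs to self.__call__.
        return_length=True,
    )
    # batch is a BatchEncoding with input_ids, attention_mask,
    # decoder_input_ids, decoder_attention_mask (plus the length /
    # decoder_length entries produced by the forwarded return_length=True).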