Unverified commit d6eab530, authored by Sam Shleifer, committed by GitHub

mbart.prepare_translation_batch: pass through kwargs (#5581)

parent 353b8f1e
@@ -198,6 +198,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         max_length: Optional[int] = None,
         padding: str = "longest",
         return_tensors: str = "pt",
+        **kwargs,
     ) -> BatchEncoding:
         """Prepare a batch that can be passed directly to an instance of MBartModel.
         Arguments:
@@ -207,6 +208,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             tgt_lang: default ro_RO (romanian), the language we are translating to
             max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*)
             padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
+            **kwargs: passed to self.__call__
         Returns:
             :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
@@ -221,6 +223,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             max_length=max_length,
             padding=padding,
             truncation=True,
+            **kwargs,
         )
         if tgt_texts is None:
             return model_inputs
@@ -232,6 +235,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             padding=padding,
             max_length=max_length,
             truncation=True,
+            **kwargs,
         )
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v