Unverified commit d6eab530, authored by Sam Shleifer, committed by GitHub

mbart.prepare_translation_batch: pass through kwargs (#5581)

parent 353b8f1e
@@ -198,6 +198,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         max_length: Optional[int] = None,
         padding: str = "longest",
         return_tensors: str = "pt",
+        **kwargs,
     ) -> BatchEncoding:
         """Prepare a batch that can be passed directly to an instance of MBartModel.
         Arguments:
@@ -207,6 +208,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             tgt_lang: default ro_RO (romanian), the language we are translating to
             max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*
             padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
+            **kwargs: passed to self.__call__
         Returns:
             :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
@@ -221,6 +223,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             max_length=max_length,
             padding=padding,
             truncation=True,
+            **kwargs,
         )
         if tgt_texts is None:
             return model_inputs
@@ -232,6 +235,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             padding=padding,
             max_length=max_length,
             truncation=True,
+            **kwargs,
         )
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v
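The forwarded **kwargs let callers pass any extra tokenizer __call__ argument through prepare_translation_batch without the method having to name it, and the same kwargs reach both the source and target encodings. A minimal sketch of how a caller could use this, assuming the facebook/mbart-large-en-ro checkpoint is available and using return_token_type_ids, a standard tokenizer __call__ flag:

    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

    # return_token_type_ids is an ordinary __call__ flag; after this change it
    # is forwarded via **kwargs to both the source and target encodings.
    batch = tokenizer.prepare_translation_batch(
        src_texts=["UN Chief Says There Is No Military Solution in Syria"],
        tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
        return_token_type_ids=False,  # forwarded through **kwargs to self.__call__
    )
    # Per the docstring: input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
    print(batch.keys())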