Unverified Commit c6a510c6 authored by Teven's avatar Teven Committed by GitHub
Browse files

Fixing missing arguments for TransfoXL tokenizer when using TextGenerationPipeline (#5465)

* overriding _parse_and_tokenize in `TextGenerationPipeline` to allow for TransfoXL tokenizer arguments
parent 6726416e
...@@ -615,6 +615,28 @@ class TextGenerationPipeline(Pipeline): ...@@ -615,6 +615,28 @@ class TextGenerationPipeline(Pipeline):
"TFCTRLLMHeadModel", "TFCTRLLMHeadModel",
] ]
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
"""
Parse arguments and tokenize
"""
# Parse arguments
if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer(
inputs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
padding=padding,
**tokenizer_kwargs,
)
return inputs
def __call__( def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment