Unverified Commit efae6645 authored by Sylvain Gugger, committed by GitHub

Fix `xxx_length` behavior when using XLNet in pipeline (#5319)

parent 393b8dc0
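
For XLNet and Transfo-XL, the text-generation pipeline silently prepends a long padding article to the user's prompt to give the model more state. Before this commit, those hidden padding tokens counted against the user-supplied `max_length`/`min_length`, so generations came back far shorter than requested. A minimal sketch of the user-visible behavior being fixed (the model id `xlnet-base-cased` and the prompt are illustrative assumptions, not part of the commit):

```python
# Sketch of the symptom this commit fixes; assumes the transformers
# text-generation pipeline API of this era (v3.x).
from transformers import pipeline

generator = pipeline("text-generation", model="xlnet-base-cased")

# The pipeline prepends PADDING_TEXT internally before tokenizing. Before
# the fix, those padding tokens ate into max_length, so much less than
# ~50 tokens of prompt + continuation came back. After the fix, the padding
# length is added to max_length/min_length before calling generate(), so
# the limits apply to the visible text only.
out = generator("Today is a beautiful day and", max_length=50)
print(out[0]["generated_text"])
```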
```diff
@@ -586,7 +586,7 @@ class TextGenerationPipeline(Pipeline):
     father initially slaps him for making such an accusation, Rasputin watches as the
     man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
     the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-    with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
+    with people, even a bishop, begging for his blessing. """

     ALLOWED_MODELS = [
         "XLNetLMHeadModel",
@@ -619,8 +619,18 @@ class TextGenerationPipeline(Pipeline):
         # Manage correct placement of the tensors
         with self.device_placement():
             if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
                 # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
+                padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
+                padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
+                # This impacts the max_length and min_length arguments, which need adjusting.
+                padding_length = padding["input_ids"].shape[-1]
+                if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None:
+                    generate_kwargs["max_length"] += padding_length
+                if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None:
+                    generate_kwargs["min_length"] += padding_length
                 inputs = self._parse_and_tokenize(
-                    self.PADDING_TEXT + prompt_text, padding=False, add_special_tokens=False
+                    padding_text + prompt_text, padding=False, add_special_tokens=False
                 )
             else:
                 inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)
```
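
Two details worth noting: the first hunk drops the hard-coded `<eod> </s> <eos>` markers from `PADDING_TEXT` because the second hunk now appends `self.tokenizer.eos_token` programmatically, and the length adjustment is what keeps the user-facing arguments meaningful. The core of that adjustment, re-expressed as a self-contained sketch (the helper name `adjust_length_kwargs` is hypothetical, not from the commit):

```python
# Hypothetical, standalone restatement of the adjustment in the diff above:
# measure the tokenized padding prefix and widen both length bounds by it.
from typing import Dict, Optional


def adjust_length_kwargs(
    generate_kwargs: Dict[str, Optional[int]], padding_length: int
) -> Dict[str, Optional[int]]:
    """Shift max_length/min_length so they budget prompt + new tokens only."""
    for key in ("max_length", "min_length"):
        if generate_kwargs.get(key) is not None:
            generate_kwargs[key] += padding_length
    return generate_kwargs


# With a 42-token padding prefix, a user-facing max_length of 50 becomes an
# effective generate() budget of 92 tokens; unset bounds are left alone.
print(adjust_length_kwargs({"max_length": 50, "min_length": None}, 42))
# -> {'max_length': 92, 'min_length': None}
```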