Unverified Commit efae6645 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix `xxx_length` behavior when using XLNet in pipeline (#5319)

parent 393b8dc0
...@@ -586,7 +586,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -586,7 +586,7 @@ class TextGenerationPipeline(Pipeline):
father initially slaps him for making such an accusation, Rasputin watches as the father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>""" with people, even a bishop, begging for his blessing. """
ALLOWED_MODELS = [ ALLOWED_MODELS = [
"XLNetLMHeadModel", "XLNetLMHeadModel",
...@@ -619,8 +619,18 @@ class TextGenerationPipeline(Pipeline): ...@@ -619,8 +619,18 @@ class TextGenerationPipeline(Pipeline):
# Manage correct placement of the tensors # Manage correct placement of the tensors
with self.device_placement(): with self.device_placement():
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]: if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
# For XLNet and TransformerXL we had an article to the prompt to give more state to the model.
padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
# This impacts max_length and min_length argument that need adjusting.
padding_length = padding["input_ids"].shape[-1]
if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None:
generate_kwargs["max_length"] += padding_length
if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None:
generate_kwargs["min_length"] += padding_length
inputs = self._parse_and_tokenize( inputs = self._parse_and_tokenize(
self.PADDING_TEXT + prompt_text, padding=False, add_special_tokens=False padding_text + prompt_text, padding=False, add_special_tokens=False
) )
else: else:
inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False) inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment