Unverified Commit 18058574 authored by Patrick von Platen, committed by GitHub

[Generation] Generation should allow to start with empty prompt (#3993)

* fix empty prompt

* fix length in generation pipeline
parent 52679fbc
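
In short: when the tokenized prompt is empty, `generate` now receives `input_ids=None` and starts from the model's BOS token instead of failing on a zero-length tensor. A minimal sketch of the resulting behavior, assuming a transformers version from around this commit (the `gpt2` checkpoint is illustrative; any causal LM with a BOS token works):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# An empty prompt encodes to a tensor of shape (1, 0) ...
encoded_prompt = tokenizer.encode("", add_special_tokens=False, return_tensors="pt")
# ... which this commit maps to input_ids=None, letting generate() fall
# back to the model's bos_token_id as the first token.
input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt

output_sequences = model.generate(input_ids=input_ids, max_length=20, do_sample=True)
print(tokenizer.decode(output_sequences[0], skip_special_tokens=True))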
@@ -221,8 +221,13 @@ def main():
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)
 
+    if encoded_prompt.size()[-1] == 0:
+        input_ids = None
+    else:
+        input_ids = encoded_prompt
+
     output_sequences = model.generate(
-        input_ids=encoded_prompt,
+        input_ids=input_ids,
         max_length=args.length + len(encoded_prompt[0]),
         temperature=args.temperature,
         top_k=args.k,
...
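
Note that `max_length` is still computed from `encoded_prompt`, which is never `None`; for an empty prompt the prompt-length term is simply zero. A quick sanity check of that arithmetic (a sketch, with a bare torch tensor standing in for the tokenizer output):

import torch

# An empty prompt encodes to a (1, 0) LongTensor, so len(encoded_prompt[0])
# contributes nothing to max_length.
encoded_prompt = torch.empty((1, 0), dtype=torch.long)
assert encoded_prompt.size()[-1] == 0
assert len(encoded_prompt[0]) == 0  # max_length reduces to args.length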
@@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline):
             else:
                 inputs = self._parse_and_tokenize(prompt_text)
 
-            if self.framework == "pt":
+            # set input_ids to None to allow empty prompt
+            if inputs["input_ids"].shape[-1] == 0:
+                inputs["input_ids"] = None
+                inputs["attention_mask"] = None
+
+            if self.framework == "pt" and inputs["input_ids"] is not None:
                 inputs = self.ensure_tensor_on_device(**inputs)
 
             input_ids = inputs["input_ids"]
 
             # Ensure that batch size = 1 (batch generation not allowed for now)
             assert (
-                input_ids.shape[0] == 1
+                input_ids is None or input_ids.shape[0] == 1
             ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
 
             output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
@@ -590,19 +595,19 @@ class TextGenerationPipeline(Pipeline):
                     )
 
                     # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
-                    record["generated_text"] = (
-                        prompt_text
-                        + text[
-                            len(
-                                self.tokenizer.decode(
-                                    input_ids[0],
-                                    skip_special_tokens=True,
-                                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-                                )
-                            ) :
-                        ]
-                    )
+                    if input_ids is None:
+                        prompt_length = 0
+                    else:
+                        prompt_length = len(
+                            self.tokenizer.decode(
+                                input_ids[0],
+                                skip_special_tokens=True,
+                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                            )
+                        )
+
+                    record["generated_text"] = prompt_text + text[prompt_length:]
 
                     result.append(record)
 
                 results += [result]
...
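
With the pipeline changes above, an empty string becomes a valid input to TextGenerationPipeline: `input_ids` is set to `None` internally, `prompt_length` is 0, and `generated_text` is just the decoded continuation. A minimal usage sketch, assuming a transformers version from around this commit (the `gpt2` checkpoint is illustrative):

from transformers import pipeline

text_generator = pipeline("text-generation", model="gpt2")
# Empty prompt: generation starts from the model's BOS token rather than
# raising on a zero-length input tensor.
print(text_generator("", max_length=20))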