Unverified Commit 18058574 authored by Patrick von Platen, committed by GitHub

[Generation] Generation should allow to start with empty prompt (#3993)

* fix empty prompt

* fix length in generation pipeline
parent 52679fbc
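
For orientation, a minimal sketch of the behavior this commit enables (the gpt2 checkpoint and the generation length are illustrative, not part of the change):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# An empty prompt encodes to zero tokens along the sequence dimension,
# which is exactly what the new checks in this commit look for.
encoded_prompt = tokenizer.encode("", add_special_tokens=False, return_tensors="pt")
print(encoded_prompt.size()[-1])  # 0

# With input_ids=None, generate() falls back to the model's default start token
# instead of receiving an empty tensor.
output_sequences = model.generate(input_ids=None, max_length=20, do_sample=True)
print(tokenizer.decode(output_sequences[0], skip_special_tokens=True))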
@@ -221,8 +221,13 @@ def main():
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)
 
+    if encoded_prompt.size()[-1] == 0:
+        input_ids = None
+    else:
+        input_ids = encoded_prompt
+
     output_sequences = model.generate(
-        input_ids=encoded_prompt,
+        input_ids=input_ids,
         max_length=args.length + len(encoded_prompt[0]),
         temperature=args.temperature,
         top_k=args.k,
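
A quick sketch of how the max_length bookkeeping above behaves for empty and non-empty prompts (args.length is replaced by a plain variable; the prompt strings are illustrative):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
length = 20  # stand-in for args.length

# Non-empty prompt: the prompt's token count is added on top of the requested length.
encoded_prompt = tokenizer.encode("Hello there", add_special_tokens=False, return_tensors="pt")
print(length + len(encoded_prompt[0]))

# Empty prompt: len(encoded_prompt[0]) is 0, so max_length stays at the requested length
# and input_ids is set to None, as in the diff above.
empty_prompt = tokenizer.encode("", add_special_tokens=False, return_tensors="pt")
print(length + len(empty_prompt[0]))  # 20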
@@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline):
                 else:
                     inputs = self._parse_and_tokenize(prompt_text)
 
-                if self.framework == "pt":
+                # set input_ids to None to allow empty prompt
+                if inputs["input_ids"].shape[-1] == 0:
+                    inputs["input_ids"] = None
+                    inputs["attention_mask"] = None
+
+                if self.framework == "pt" and inputs["input_ids"] is not None:
                     inputs = self.ensure_tensor_on_device(**inputs)
 
                 input_ids = inputs["input_ids"]
 
                 # Ensure that batch size = 1 (batch generation not allowed for now)
                 assert (
-                    input_ids.shape[0] == 1
+                    input_ids is None or input_ids.shape[0] == 1
                 ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
 
                 output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
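
The same guard, pulled out of the pipeline as a standalone sketch (the dict only mirrors the shape of what the tokenization step produces; the function name is illustrative):

import torch

def drop_empty_prompt(inputs):
    # Mirrors the new check: a zero-length sequence dimension means an empty prompt,
    # so both tensors are replaced by None and generation starts from scratch.
    if inputs["input_ids"].shape[-1] == 0:
        inputs["input_ids"] = None
        inputs["attention_mask"] = None
    return inputs

empty = {
    "input_ids": torch.empty(1, 0, dtype=torch.long),
    "attention_mask": torch.empty(1, 0, dtype=torch.long),
}
print(drop_empty_prompt(empty)["input_ids"])  # None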
@@ -590,19 +595,19 @@ class TextGenerationPipeline(Pipeline):
                         )
 
                         # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
-                        record["generated_text"] = (
-                            prompt_text
-                            + text[
-                                len(
+                        if input_ids is None:
+                            prompt_length = 0
+                        else:
+                            prompt_length = len(
                                 self.tokenizer.decode(
                                     input_ids[0],
                                     skip_special_tokens=True,
                                     clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                                 )
-                                ) :
-                            ]
-                        )
+                            )
+
+                        record["generated_text"] = prompt_text + text[prompt_length:]
                     result.append(record)
 
             results += [result]
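
End to end, the pipeline changes mean an empty prompt is now a valid input. A usage sketch under the assumption of a gpt2 checkpoint (checkpoint and max_length are illustrative):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

# Empty prompt: input_ids becomes None internally and prompt_length is treated as 0,
# so the generated text is returned as-is.
print(generator("", max_length=30))

# Non-empty prompt: the prompt prefix of the decoded output is replaced by the
# original prompt_text, exactly as before.
print(generator("The meaning of life is", max_length=30))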