"examples/legacy/seq2seq/run_eval.py" did not exist on "b76cb1c3dfc64d1dcaddc3d6d9313dddeb626d05"
Unverified Commit 54192058 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Test OPT] Add batch generation test opt (#17359)

* up

* up
parent 48c22691
...@@ -366,6 +366,47 @@ class OPTGenerationTest(unittest.TestCase): ...@@ -366,6 +366,47 @@ class OPTGenerationTest(unittest.TestCase):
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS) self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
def test_batch_generation(self):
model_id = "facebook/opt-350m"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = OPTForCausalLM.from_pretrained(model_id)
model.to(torch_device)
tokenizer.padding_side = "left"
# use different length sentences to test batching
sentences = [
"Hello, my dog is a little",
"Today, I",
]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(torch_device)
outputs = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
)
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
output_non_padded = model.generate(input_ids=inputs_non_padded)
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
expected_output_sentence = [
"Hello, my dog is a little bit of a dork.\nI'm a little bit",
"Today, I was in the middle of a conversation with a friend about the",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
def test_generation_post_attn_layer_norm(self): def test_generation_post_attn_layer_norm(self):
model_id = "facebook/opt-350m" model_id = "facebook/opt-350m"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment