Unverified Commit 8fff61b9 authored by Yih-Dar, committed by GitHub

Fix failing `test_batch_generation` for bloom (#25718)



fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent f01459c7
@@ -449,9 +449,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
-        input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
-        input_ids = input_ids["input_ids"].to(torch_device)
-        attention_mask = input_ids["attention_mask"]
+        inputs = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+        input_ids = inputs["input_ids"].to(torch_device)
+        attention_mask = inputs["attention_mask"]
         greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
         self.assertEqual(
...
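Context for the change: the old test reused the name `input_ids` for both the tokenizer's `BatchEncoding` and the extracted tensor, so the later `input_ids["attention_mask"]` lookup indexed a tensor rather than the encoding and the test failed. The patch keeps the full encoding under a separate `inputs` name. Below is a minimal sketch of the same corrected pattern outside the test harness; the checkpoint name and the CPU-only device handling are illustrative assumptions, not taken from the diff.

```python
# Minimal sketch (assumptions: checkpoint name and CPU execution are illustrative).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-560m"  # assumption: any BLOOM checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

sentences = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]

# Keep the whole BatchEncoding so both input_ids and attention_mask stay accessible;
# tokenizer(...) is the modern equivalent of tokenizer.batch_encode_plus(...).
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# Greedy batch generation, mirroring the test's call.
greedy_output = model.generate(input_ids, attention_mask=attention_mask, max_length=50, do_sample=False)
print(tokenizer.batch_decode(greedy_output, skip_special_tokens=True))
```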