Unverified Commit a582cfce authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `GitModelIntegrationTest.test_batched_generation` device issue (#21362)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 73a2ff69
......@@ -508,9 +508,8 @@ class GitModelIntegrationTest(unittest.TestCase):
# we have to prepare `input_ids` with the same batch size as `pixel_values`
start_token_id = model.config.bos_token_id
generated_ids = model.generate(
pixel_values=pixel_values, input_ids=torch.tensor([[start_token_id], [start_token_id]]), max_length=50
)
input_ids = torch.tensor([[start_token_id], [start_token_id]], device=torch_device)
generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEquals(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment