Commit b370cc7e authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[gpu] Fixup fdd61b19

parent f5516805
......@@ -205,7 +205,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape).long()
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = self.seq_length // 2
attn_mask[:, half_seq_length:] = 0
......@@ -222,7 +222,9 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
)
# get two different outputs
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment