Commit 66c82765 authored by patrickvonplaten's avatar patrickvonplaten
Browse files

fix typo in test gpt2

parent 314bdc7c
......@@ -343,7 +343,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = torch.tensor([[463, 3290]], dtype=torch.long, device=torch_device) # The dog
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [
464,
3290,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment