Unverified Commit badb9d2a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Correct naming pegasus x (#18896)

* add first generation tutorial

* [Pegasus X] correct naming

* [Generation] Remove
parent 591cfc6c
......@@ -559,7 +559,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
return PegasusTokenizer.from_pretrained("google/pegasus-x-base")
def test_inference_no_head(self):
model = PegasusXModel.from_pretrained("pegasus-x-base").to(torch_device)
model = PegasusXModel.from_pretrained("google/pegasus-x-base").to(torch_device)
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
decoder_input_ids = _long_tensor([[2, 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588]])
inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids)
......@@ -574,7 +574,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
def test_inference_head(self):
model = PegasusXForConditionalGeneration.from_pretrained("pegasus-x-base").to(torch_device)
model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base").to(torch_device)
# change to intended input
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment