Unverified Commit e135a6c9 authored by Francisco Kurucz's avatar Francisco Kurucz Committed by GitHub
Browse files

Fix flax GPT-J-6B linking model in tests (#20556)

parent 24124709
...@@ -202,7 +202,7 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes ...@@ -202,7 +202,7 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left") tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True) inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gptj-6B") model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
model.do_sample = False model.do_sample = False
model.config.pad_token_id = model.config.eos_token_id model.config.pad_token_id = model.config.eos_token_id
...@@ -323,6 +323,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes ...@@ -323,6 +323,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
@tooslow @tooslow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("EleutherAI/gptj-6B") model = model_class_name.from_pretrained("EleutherAI/gpt-j-6B")
outputs = model(np.ones((1, 1))) outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs) self.assertIsNotNone(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment