Unverified commit 2bef3433, authored by Patrick von Platen, committed by GitHub

[Flax] Correct all return tensors to numpy (#13307)

* fix_torch_device_generate_test

* remove @

* finish find and replace
parent 8aa67fc1
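
The change is mechanical: every Flax example that tokenized with return_tensors="jax" now tokenizes with return_tensors="np" and passes the resulting numpy arrays straight to the model. A minimal sketch of the pattern, assuming the distilgpt2 checkpoint (used here purely for illustration, alongside the transformers and flax packages):

from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Flax models accept plain numpy arrays, so no framework-specific
# tensor conversion is needed at tokenization time.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)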
@@ -110,7 +110,7 @@ def main():
     inputs = tokenizer(
         example["question"],
         example["context"],
-        return_tensors="jax",
+        return_tensors="np",
         max_length=4096,
         padding="max_length",
         truncation=True,
......
@@ -1121,7 +1121,7 @@ FLAX_CAUSAL_LM_SAMPLE = r"""
         >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
         >>> model = {model_class}.from_pretrained('{checkpoint}')
 
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
         >>> outputs = model(**inputs)
 
         >>> # retrieve logits for next token
......
@@ -231,7 +231,7 @@ class FlaxGenerationMixin:
         >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
         >>> input_context = "The dog"
         >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
+        >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
         >>> # generate candidates using sampling
         >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
         >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
......
@@ -757,7 +757,7 @@ FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
         >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
         >>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
 
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
         >>> outputs = model(**inputs)
 
         >>> prediction_logits = outputs.prediction_logits
......
@@ -1567,7 +1567,7 @@ FLAX_BIG_BIRD_FOR_PRETRAINING_DOCSTRING = """
         >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
         >>> model = FlaxBigBirdForPreTraining.from_pretrained('google/bigbird-roberta-base')
 
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
         >>> outputs = model(**inputs)
 
         >>> prediction_logits = outputs.prediction_logits
......
@@ -761,7 +761,7 @@ FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
         >>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
         >>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
 
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
         >>> outputs = model(**inputs)
 
         >>> prediction_logits = outputs.logits
......
@@ -512,7 +512,7 @@ FLAX_VISION_MODEL_DOCSTRING = """
         >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
         >>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
 
-        >>> inputs = feature_extractor(images=image, return_tensors="jax")
+        >>> inputs = feature_extractor(images=image, return_tensors="np")
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
 """
@@ -592,7 +592,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
         >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
         >>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
 
-        >>> inputs = feature_extractor(images=image, return_tensors="jax")
+        >>> inputs = feature_extractor(images=image, return_tensors="np")
         >>> outputs = model(**inputs)
 
         >>> logits = outputs.logits
......
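
The two ViT hunks above apply the same switch to the feature extractor. A sketch of that path, assuming the google/vit-base-patch16-224 checkpoint and a synthetic image standing in for a real one:

import numpy as np
from PIL import Image
from transformers import ViTFeatureExtractor, FlaxViTForImageClassification

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# A blank 224x224 RGB image used purely for illustration.
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

inputs = feature_extractor(images=image, return_tensors="np")  # numpy pixel values
outputs = model(**inputs)
print(outputs.logits.argmax(-1))  # predicted ImageNet class id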
@@ -453,7 +453,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
             padding="max_length",
             truncation_strategy="only_first",
             truncation=True,
-            return_tensors="jax",
+            return_tensors="np",
         )
 
         self.assertEqual(1024, dct["input_ids"].shape[1])
......
@@ -213,7 +213,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
     @slow
     def test_batch_generation(self):
         tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
-        inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
+        inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
 
         model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
         model.do_sample = False
......
@@ -204,7 +204,7 @@ class FlaxGPTNeoModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
     @slow
     def test_batch_generation(self):
         tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
-        inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
+        inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
 
         model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
         model.do_sample = False
......
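
The two test hunks above exercise batched generation, where a left-padding tokenizer keeps every prompt flush against the position where decoding starts. A sketch of that pattern with numpy tensors, assuming generate returns an output object whose sequences field holds the generated ids (as these tests expect):

from transformers import GPT2Tokenizer, FlaxGPT2LMHeadModel

# Left padding plus an explicit pad token, so both prompts end at the
# same position and generation continues from there.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True)
outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=20,
    do_sample=False,  # greedy decoding, as in the tests above
)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))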