Unverified Commit 74e6111b authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Fix test and docs (#14399)

parent 4ce74edf
......@@ -795,6 +795,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
Examples::
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests
......@@ -809,7 +810,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
"""
inputs = input_processing(
func=self.call,
......
......@@ -371,7 +371,7 @@ class TFViTModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_image_classification_head(self):
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
feature_extractor = self.default_feature_extractor
image = prepare_img()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment