Unverified Commit 96b5d7db authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Remove pt-like calls on tf tensor (#18393)

parent 679d68a1
......@@ -934,9 +934,9 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: maillot
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
Predicted class: ptarmigan
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment