Unverified Commit 96b5d7db authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Remove pt-like calls on tf tensor (#18393)

parent 679d68a1
...@@ -934,9 +934,9 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati ...@@ -934,9 +934,9 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> logits = outputs.logits >>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes >>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item() >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
>>> print("Predicted class:", model.config.id2label[predicted_class_idx]) >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
Predicted class: maillot Predicted class: ptarmigan
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment