"...config/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "66fec2c8fc6156fc33ae908d9994556838d842e2"
Unverified Commit 9db2eebb authored by Johannes Kolbe, committed by GitHub

add vit tf doctest with @add_code_sample_docstrings (#16636)



* add vit tf doctest with @add_code_sample_docstrings

* add labels string back in
Co-authored-by: Johannes Kolbe <johannes.kolbe@tech.better.team>
parent 4ef0abb7
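
For readers unfamiliar with the decorator this commit adopts: `add_code_sample_docstrings` builds a ready-to-run usage example from a shared template and appends it to the decorated method's docstring, which is why the handwritten examples in the diff below can be deleted. A minimal sketch of the pattern (illustrative only, not the actual `transformers.utils` implementation; the template text is abridged):

```python
# Illustrative sketch of the decorator pattern -- NOT the real transformers
# implementation, whose templates live in transformers.utils.
SAMPLE_TEMPLATE = """
    Example:

    >>> from transformers import {processor_class}, {model_class}
    >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")
"""


def add_code_sample_docstrings(*, processor_class, checkpoint, **_unused):
    """Append a templated code sample to the decorated method's docstring."""

    def decorator(fn):
        # For a method defined in a class body, __qualname__ is e.g.
        # "TFViTModel.call", so the first component names the model class.
        model_class = fn.__qualname__.split(".")[0]
        fn.__doc__ = (fn.__doc__ or "") + SAMPLE_TEMPLATE.format(
            processor_class=processor_class,
            checkpoint=checkpoint,
            model_class=model_class,
        )
        return fn

    return decorator
```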
@@ -33,14 +33,23 @@ from ...modeling_tf_utils import (
     unpack_inputs,
 )
 from ...tf_utils import shape_list
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_vit import ViTConfig


 logger = logging.get_logger(__name__)

+# General docstring
 _CONFIG_FOR_DOC = "ViTConfig"
-_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
+_FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"

 # Inspired by
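
A note on `_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]`: for ViT-Base/16, a 224x224 input is cut into 16x16 patches, giving (224/16)^2 = 196 patch tokens; the prepended [CLS] token makes 197, and the hidden size is 768. A quick check:

```python
image_size, patch_size, hidden_size = 224, 16, 768  # ViT-Base/16 defaults
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196
assert [1, num_patches + 1, hidden_size] == [1, 197, 768]  # +1 for [CLS]
```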
@@ -645,7 +654,14 @@ class TFViTModel(TFViTPreTrainedModel):
     @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutputWithPooling,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -656,26 +672,6 @@ class TFViTModel(TFViTPreTrainedModel):
         return_dict: Optional[bool] = None,
         training: bool = False,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        r"""
-        Returns:
-
-        Examples:
-
-        ```python
-        >>> from transformers import ViTFeatureExtractor, TFViTModel
-        >>> from PIL import Image
-        >>> import requests
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
-        >>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
-
-        >>> inputs = feature_extractor(images=image, return_tensors="tf")
-        >>> outputs = model(**inputs)
-        >>> last_hidden_states = outputs.last_hidden_state
-        ```"""
         outputs = self.vit(
             pixel_values=pixel_values,
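
Since the decorator receives `expected_output=_EXPECTED_OUTPUT_SHAPE`, the generated sample is a checked doctest rather than a free-floating example. Based on the example removed above and the new constants, the generated sample for `TFViTModel` should look roughly like this (a sketch; the exact wording comes from the templates in `transformers.utils`):

```python
>>> from transformers import ViTFeatureExtractor, TFViTModel
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> list(outputs.last_hidden_state.shape)
[1, 197, 768]
```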
@@ -744,7 +740,13 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
     @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -761,30 +763,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-
-        Returns:
-
-        Examples:
-
-        ```python
-        >>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
-        >>> import tensorflow as tf
-        >>> from PIL import Image
-        >>> import requests
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
-        >>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
-
-        >>> inputs = feature_extractor(images=image, return_tensors="tf")
-        >>> outputs = model(**inputs)
-        >>> logits = outputs.logits
-
-        >>> # model predicts one of the 1000 ImageNet classes
-        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
-        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
-        ```"""
+        """
         outputs = self.vit(
             pixel_values=pixel_values,
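
Likewise for classification: with `expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT`, the generated doctest asserts the decoded label rather than just a shape. A sketch of the resulting sample, assembled from the removed example above and the new constants (the exact generated text is an assumption):

```python
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
>>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = int(tf.math.argmax(logits, axis=-1)[0])
>>> print(model.config.id2label[predicted_class_idx])
Egyptian cat
```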
...
@@ -36,6 +36,7 @@ src/transformers/models/van/modeling_van.py
 src/transformers/models/vilt/modeling_vilt.py
 src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
 src/transformers/models/vit/modeling_vit.py
+src/transformers/models/vit/modeling_tf_vit.py
 src/transformers/models/vit_mae/modeling_vit_mae.py
 src/transformers/models/wav2vec2/modeling_wav2vec2.py
 src/transformers/models/wav2vec2/tokenization_wav2vec2.py
...
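
Registering `src/transformers/models/vit/modeling_tf_vit.py` in this list opts the file's docstring examples into the documentation test run. Locally, such doctests can be exercised with the standard-library runner; a hedged sketch (assumes TensorFlow, Pillow, and network access for the image download; the project's actual CI wiring may differ):

```python
# Hedged sketch: run a module's docstring examples with the stdlib doctest
# runner. This only demonstrates the principle, not the project's CI setup.
import doctest

import transformers.models.vit.modeling_tf_vit as tf_vit

results = doctest.testmod(tf_vit, verbose=False)
print(f"{results.failed} failures out of {results.attempted} examples")
```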