Unverified commit 90166121 authored by NielsRogge, committed by GitHub

Add general vision docstrings (#15501)

* Add general docstrings

* Remove legacy docstrings

* Add BEiT

* Add DeiT

* Add SegFormer

* Fix BEiT output class

* Fix missing return_dict
parent e2b6e73f
@@ -1288,6 +1288,58 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
```
"""
PT_VISION_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
PT_VISION_SEQ_CLASS_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
{expected_output}
```
"""
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
@@ -1301,6 +1353,8 @@ PT_SAMPLE_DOCSTRINGS = {
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
"VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
"ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
}
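Once filled in, each registered sample is an executable doctest: the literal after the final `>>>` line (a shape list or a predicted label) is compared verbatim against the actual output. Below is a hedged sketch of that check using the standard library runner; the tiny sample is a stand-in, not the project's test harness.

```python
import doctest

# Illustrative stand-in for a filled-in vision sample: the expected output
# below is checked character-for-character by the doctest runner.
filled_sample = """
>>> last_hidden_state_shape = (1, 197, 768)
>>> list(last_hidden_state_shape)
[1, 197, 768]
"""

test = doctest.DocTestParser().get_doctest(filled_sample, {}, "vision_sample", None, 0)
results = doctest.DocTestRunner().run(test)
assert results.failed == 0  # all examples in the sample passed
```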
@@ -1639,8 +1693,12 @@ def add_code_sample_docstrings(
code_sample = sample_docstrings["AudioXVector"]
elif "Model" in model_class and modality == "audio":
code_sample = sample_docstrings["SpeechBaseModel"]
elif "Model" in model_class and modality == "vision":
code_sample = sample_docstrings["VisionBaseModel"]
elif "Model" in model_class or "Encoder" in model_class:
code_sample = sample_docstrings["BaseModel"]
elif "ImageClassification" in model_class:
code_sample = sample_docstrings["ImageClassification"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
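The body of `add_code_sample_docstrings` is elided from this hunk. As a simplified sketch of its mechanics (a stand-in, not the actual implementation, which also builds the `Returns:` section from `output_type` and `config_class`), the decorator renders the selected `code_sample` and appends it to the wrapped `forward`'s docstring:

```python
# Simplified stand-in for add_code_sample_docstrings; the name, signature, and
# body are illustrative. The real implementation lives in file_utils.py.
def add_code_sample_docstrings_sketch(*, processor_class, checkpoint, expected_output, code_sample):
    def docstring_decorator(fn):
        # e.g. fn.__qualname__ == "BeitModel.forward" -> "BeitModel"
        model_class = fn.__qualname__.split(".")[0]
        built_doc = code_sample.format(
            model_class=model_class,
            processor_class=processor_class,
            checkpoint=checkpoint,
            expected_output=expected_output,
        )
        # Append the rendered sample so it shows up in help() and the docs.
        fn.__doc__ = (fn.__doc__ or "") + built_doc
        return fn

    return docstring_decorator
```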
...
@@ -25,7 +25,12 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -34,8 +39,17 @@ from .configuration_beit import BeitConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "BeitConfig"
_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224"
_FEAT_EXTRACTOR_FOR_DOC = "BeitFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224-pt22k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/beit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/beit-base-patch16-224",
@@ -613,7 +627,14 @@ class BeitModel(BeitPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BeitModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -623,26 +644,6 @@ class BeitModel(BeitPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import BeitFeatureExtractor, BeitModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
>>> model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -813,7 +814,13 @@ class BeitForImageClassification(BeitPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -828,31 +835,8 @@ class BeitForImageClassification(BeitPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import BeitFeatureExtractor, BeitForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
>>> model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.beit(
pixel_values,
head_mask=head_mask,
...
@@ -28,6 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
@@ -40,8 +41,18 @@ from .configuration_deit import DeiTConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "DeiTConfig"
_FEAT_EXTRACTOR_FOR_DOC = "DeiTFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
_EXPECTED_OUTPUT_SHAPE = [1, 198, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/deit-base-distilled-patch16-224",
@@ -462,7 +473,14 @@ class DeiTModel(DeiTPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -471,26 +489,6 @@ class DeiTModel(DeiTPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import DeiTFeatureExtractor, DeiTModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
>>> model = DeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224", add_pooling_layer=False)
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -707,7 +705,13 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DeiTForImageClassificationWithTeacherOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=DeiTForImageClassificationWithTeacherOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -716,29 +720,6 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Examples:
```python
>>> from transformers import DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
>>> model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deit(
...
@@ -24,7 +24,12 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -33,8 +38,18 @@ from .configuration_segformer import SegformerConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "nvidia/segformer-b0-finetuned-ade-512-512"
# General docstring
_CONFIG_FOR_DOC = "SegformerConfig"
_FEAT_EXTRACTOR_FOR_DOC = "SegformerFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
_EXPECTED_OUTPUT_SHAPE = [1, 256, 256]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"nvidia/segformer-b0-finetuned-ade-512-512",
@@ -478,29 +493,15 @@ class SegformerModel(SegformerPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
r"""
Returns:
Examples:
```python
>>> from transformers import SegformerFeatureExtractor, SegformerModel
>>> from PIL import Image
>>> import requests
>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b0")
>>> model = SegformerModel.from_pretrained("nvidia/mit-b0")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> sequence_output = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -546,7 +547,13 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -560,29 +567,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import SegformerFeatureExtractor, SegformerForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b0")
>>> model = SegformerForImageClassification.from_pretrained("nvidia/mit-b0")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.segformer(
...
@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -33,8 +33,18 @@ from .configuration_swin import SwinConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
# General docstring
_CONFIG_FOR_DOC = "SwinConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/swin-tiny-patch4-window7-224",
@@ -686,7 +696,14 @@ class SwinModel(SwinPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -695,27 +712,6 @@ class SwinModel(SwinPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, SwinModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -784,7 +780,13 @@ class SwinForImageClassification(SwinPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -799,30 +801,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, SwinForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.swin(
...
@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -33,8 +33,18 @@ from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
_FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'Egyptian cat'"
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/vit-base-patch16-224",
@@ -491,7 +501,14 @@ class ViTModel(ViTPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -501,26 +518,6 @@ class ViTModel(ViTPreTrainedModel):
interpolate_pos_encoding=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -597,7 +594,13 @@ class ViTForImageClassification(ViTPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -613,29 +616,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
>>> model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
...
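With the legacy inline examples removed, the rendered samples live on the `forward` methods themselves, so the effect of this PR can be checked directly (a quick sketch, assuming `transformers` at this revision is installed):

```python
from transformers import ViTModel

# The decorator appended the rendered sample (checkpoint, expected shape and
# all) to the docstring, so it appears in help() and the generated docs.
print(ViTModel.forward.__doc__)
```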