Unverified Commit 90166121 authored by NielsRogge, committed by GitHub

Add general vision docstrings (#15501)

* Add general docstrings

* Remove legacy docstrings

* Add BEiT

* Add DeiT

* Add SegFormer

* Fix beit output class

* Fix missing return_dict
parent e2b6e73f
@@ -1288,6 +1288,58 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
```
"""
PT_VISION_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
PT_VISION_SEQ_CLASS_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
{expected_output}
```
"""
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
@@ -1301,6 +1353,8 @@ PT_SAMPLE_DOCSTRINGS = {
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
"VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
"ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
}
@@ -1639,8 +1693,12 @@ def add_code_sample_docstrings(
code_sample = sample_docstrings["AudioXVector"]
elif "Model" in model_class and modality == "audio":
code_sample = sample_docstrings["SpeechBaseModel"]
elif "Model" in model_class and modality == "vision":
code_sample = sample_docstrings["VisionBaseModel"]
elif "Model" in model_class or "Encoder" in model_class: elif "Model" in model_class or "Encoder" in model_class:
code_sample = sample_docstrings["BaseModel"] code_sample = sample_docstrings["BaseModel"]
elif "ImageClassification" in model_class:
code_sample = sample_docstrings["ImageClassification"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
...
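For readers unfamiliar with `add_code_sample_docstrings`: the branches above select a template from the sample-docstring registry based on the class name of the decorated `forward` (and now the new `modality` argument), after which the decorator formats the template with the supplied checkpoint and processor names and appends the result to the method's docstring. A simplified, hypothetical stand-in for that flow, not the library's exact implementation:

```python
# Simplified stand-in for the real decorator (illustration only).
def add_code_sample_docstrings_sketch(
    *, processor_class, checkpoint, expected_output=None, modality=None, **unused
):
    def decorator(fn):
        # e.g. "BeitModel.forward" -> "BeitModel"
        model_class = fn.__qualname__.split(".")[0]
        # Same dispatch order as the diff above: vision base models first,
        # then image classification heads.
        if "Model" in model_class and modality == "vision":
            code_sample = PT_SAMPLE_DOCSTRINGS["VisionBaseModel"]
        elif "ImageClassification" in model_class:
            code_sample = PT_SAMPLE_DOCSTRINGS["ImageClassification"]
        else:
            raise ValueError(f"Docstring can't be built for model {model_class}")
        fn.__doc__ = (fn.__doc__ or "") + code_sample.format(
            processor_class=processor_class,
            model_class=model_class,
            checkpoint=checkpoint,
            expected_output=expected_output,
        )
        return fn
    return decorator
```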
@@ -25,7 +25,12 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -34,8 +39,17 @@ from .configuration_beit import BeitConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "BeitConfig" _CONFIG_FOR_DOC = "BeitConfig"
_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224" _FEAT_EXTRACTOR_FOR_DOC = "BeitFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224-pt22k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/beit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/beit-base-patch16-224",
@@ -613,7 +627,14 @@ class BeitModel(BeitPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BeitModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -623,26 +644,6 @@ class BeitModel(BeitPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import BeitFeatureExtractor, BeitModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
>>> model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -813,7 +814,13 @@ class BeitForImageClassification(BeitPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -828,31 +835,8 @@ class BeitForImageClassification(BeitPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
Returns:
Examples:
```python
>>> from transformers import BeitFeatureExtractor, BeitForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
>>> model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.beit(
pixel_values,
head_mask=head_mask,
...
@@ -28,6 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
@@ -40,8 +41,18 @@ from .configuration_deit import DeiTConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "DeiTConfig" _CONFIG_FOR_DOC = "DeiTConfig"
_FEAT_EXTRACTOR_FOR_DOC = "DeiTFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224" _CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
_EXPECTED_OUTPUT_SHAPE = [1, 198, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/deit-base-distilled-patch16-224",
@@ -462,7 +473,14 @@ class DeiTModel(DeiTPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -471,26 +489,6 @@ class DeiTModel(DeiTPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import DeiTFeatureExtractor, DeiTModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
>>> model = DeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224", add_pooling_layer=False)
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -707,7 +705,13 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DeiTForImageClassificationWithTeacherOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=DeiTForImageClassificationWithTeacherOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -716,29 +720,6 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Examples:
```python
>>> from transformers import DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
>>> model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deit(
...
@@ -24,7 +24,12 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -33,8 +38,18 @@ from .configuration_segformer import SegformerConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "nvidia/segformer-b0-finetuned-ade-512-512"
# General docstring
_CONFIG_FOR_DOC = "SegformerConfig" _CONFIG_FOR_DOC = "SegformerConfig"
_FEAT_EXTRACTOR_FOR_DOC = "SegformerFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
_EXPECTED_OUTPUT_SHAPE = [1, 256, 256]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"nvidia/segformer-b0-finetuned-ade-512-512",
@@ -478,29 +493,15 @@ class SegformerModel(SegformerPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
r"""
Returns:
Examples:
```python
>>> from transformers import SegformerFeatureExtractor, SegformerModel
>>> from PIL import Image
>>> import requests
>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b0")
>>> model = SegformerModel.from_pretrained("nvidia/mit-b0")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> sequence_output = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -546,7 +547,13 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -560,29 +567,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
Returns:
Examples:
```python
>>> from transformers import SegformerFeatureExtractor, SegformerForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b0")
>>> model = SegformerForImageClassification.from_pretrained("nvidia/mit-b0")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.segformer(
...
@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -33,8 +33,18 @@ from .configuration_swin import SwinConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
# General docstring
_CONFIG_FOR_DOC = "SwinConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/swin-tiny-patch4-window7-224",
@@ -686,7 +696,14 @@ class SwinModel(SwinPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -695,27 +712,6 @@ class SwinModel(SwinPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, SwinModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -784,7 +780,13 @@ class SwinForImageClassification(SwinPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -799,30 +801,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, SwinForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.swin(
...
@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
@@ -33,8 +33,18 @@ from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "ViTConfig" _CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224" _FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'Egyptian cat'"
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/vit-base-patch16-224",
@@ -491,7 +501,14 @@ class ViTModel(ViTPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values=None,
@@ -501,26 +518,6 @@ class ViTModel(ViTPreTrainedModel):
interpolate_pos_encoding=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -597,7 +594,13 @@ class ViTForImageClassification(ViTPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values=None,
@@ -613,29 +616,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
>>> model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
...