Unverified Commit ca265374 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Adding revamp docs for vision_transformers and regnet (#5856)

* Add docs for regnet, still need to update the comment docs on models

* Fix a little typo on .rst file

* Update regnet docstring

* Add vision_transformer docs, and fix typo on regnet docs

* Update docstring to make sure it does not exceed 120 chars per line

* Improve formatting

* Change the new line location for vision_transformer docstring
parent a64c6749
RegNet
======
.. currentmodule:: torchvision.models
The RegNet model is based on the `Designing Network Design Spaces
<https://arxiv.org/abs/2003.13678>`_ paper.
Model builders
--------------
The following model builders can be used to instantiate a RegNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.regnet.RegNet`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
regnet_y_400mf
regnet_y_800mf
regnet_y_1_6gf
regnet_y_3_2gf
regnet_y_8gf
regnet_y_16gf
regnet_y_32gf
regnet_y_128gf
regnet_x_400mf
regnet_x_800mf
regnet_x_1_6gf
regnet_x_3_2gf
regnet_x_8gf
regnet_x_16gf
regnet_x_32gf
......@@ -10,7 +10,7 @@ The ResNet model is based on the `Deep Residual Learning for Image Recognition
Model builders
--------------
The following model builders can be used to instanciate a ResNet model, with or
The following model builders can be used to instantiate a ResNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.resnet.ResNet`` base class. Please refer to the `source
code
......
......@@ -11,7 +11,7 @@ paper.
Model builders
--------------
The following model builders can be used to instanciate a SqueezeNet model, with or
The following model builders can be used to instantiate a SqueezeNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.squeezenet.SqueezeNet`` base class. Please refer to the `source
code
......
......@@ -10,7 +10,7 @@ Image Recognition <https://arxiv.org/abs/1409.1556>`_ paper.
Model builders
--------------
The following model builders can be used to instanciate a VGG model, with or
The following model builders can be used to instantiate a VGG model, with or
without pre-trained weights. All the model buidlers internally rely on the
``torchvision.models.vgg.VGG`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_ for
......
VisionTransformer
=================
.. currentmodule:: torchvision.models
The VisionTransformer model is based on the `An Image is Worth 16x16 Words:
Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_ paper.
Model builders
--------------
The following model builders can be used to instantiate a VisionTransformer model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.vision_transformer.VisionTransformer`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
vit_b_16
vit_b_32
vit_l_16
vit_l_32
vit_h_14
......@@ -36,9 +36,11 @@ weights:
.. toctree::
:maxdepth: 1
models/regnet
models/resnet
models/squeezenet
models/vgg
models/vision_transformer
Table of all available classification weights
......
......@@ -861,11 +861,20 @@ class RegNet_X_32GF_Weights(WeightsEnum):
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_400MF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_400MF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_400MF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_400MF_Weights
:members:
"""
weights = RegNet_Y_400MF_Weights.verify(weights)
......@@ -877,11 +886,20 @@ def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_800MF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_800MF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_800MF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_800MF_Weights
:members:
"""
weights = RegNet_Y_800MF_Weights.verify(weights)
......@@ -893,11 +911,20 @@ def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_1.6GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_1_6GF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_1_6GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_1_6GF_Weights
:members:
"""
weights = RegNet_Y_1_6GF_Weights.verify(weights)
......@@ -911,11 +938,20 @@ def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_3.2GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_3_2GF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_3_2GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_3_2GF_Weights
:members:
"""
weights = RegNet_Y_3_2GF_Weights.verify(weights)
......@@ -929,11 +965,20 @@ def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_8GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_8GF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_8GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_8GF_Weights
:members:
"""
weights = RegNet_Y_8GF_Weights.verify(weights)
......@@ -947,11 +992,20 @@ def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bo
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_16GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_16GF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_16GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_16GF_Weights
:members:
"""
weights = RegNet_Y_16GF_Weights.verify(weights)
......@@ -965,11 +1019,20 @@ def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress:
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_32GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_32GF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_32GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_32GF_Weights
:members:
"""
weights = RegNet_Y_32GF_Weights.verify(weights)
......@@ -983,12 +1046,20 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress:
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_128GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
NOTE: Pretrained weights are not available for this model.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_Y_128GF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_Y_128GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_Y_128GF_Weights
:members:
"""
weights = RegNet_Y_128GF_Weights.verify(weights)
......@@ -1002,11 +1073,20 @@ def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_400MF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_X_400MF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_X_400MF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_400MF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_400MF_Weights
:members:
"""
weights = RegNet_X_400MF_Weights.verify(weights)
......@@ -1018,11 +1098,20 @@ def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_800MF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (RegNet_X_800MF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`torchvision.models.regnet.RegNet_X_800MF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_800MF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_800MF_Weights
:members:
"""
weights = RegNet_X_800MF_Weights.verify(weights)
......@@ -1034,7 +1123,20 @@ def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_1.6GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (:class:`torchvision.models.regnet.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_1_6GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_1_6GF_Weights
:members:
Args:
weights (RegNet_X_1_6GF_Weights, optional): The pretrained weights for the model
......@@ -1050,7 +1152,20 @@ def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_3.2GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (:class:`torchvision.models.regnet.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_3_2GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_3_2GF_Weights
:members:
Args:
weights (RegNet_X_3_2GF_Weights, optional): The pretrained weights for the model
......@@ -1066,7 +1181,20 @@ def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_8GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (:class:`torchvision.models.regnet.RegNet_X_8GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_8GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_8GF_Weights
:members:
Args:
weights (RegNet_X_8GF_Weights, optional): The pretrained weights for the model
......@@ -1082,7 +1210,20 @@ def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bo
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_16GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (:class:`torchvision.models.regnet.RegNet_X_16GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_16GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_16GF_Weights
:members:
Args:
weights (RegNet_X_16GF_Weights, optional): The pretrained weights for the model
......@@ -1098,7 +1239,20 @@ def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress:
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_32GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
`Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.
Args:
weights (:class:`torchvision.models.regnet.RegNet_X_32GF_Weights`, optional): The pretrained weights to use.
See :class:`~torchvision.models.regnet.RegNet_X_32GF_Weights` below for more details and possible values.
By default, no pretrained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
for more detail about the classes.
.. autoclass:: torchvision.models.regnet.RegNet_X_32GF_Weights
:members:
Args:
weights (RegNet_X_32GF_Weights, optional): The pretrained weights for the model
......
......@@ -490,11 +490,20 @@ class ViT_H_14_Weights(WeightsEnum):
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_B_16_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.vision_transformer.ViT_B_16_Weights`, optional): The pretrained
weights to use. See :class:`~torchvision.models.vision_transformer.ViT_B_16_Weights`
below for more details and possible values. By default, no pre-trained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.vision_transformer.ViT_B_16_Weights
:members:
"""
weights = ViT_B_16_Weights.verify(weights)
......@@ -514,11 +523,20 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_B_32_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.vision_transformer.ViT_B_32_Weights`, optional): The pretrained
weights to use. See :class:`~torchvision.models.vision_transformer.ViT_B_32_Weights`
below for more details and possible values. By default, no pre-trained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.vision_transformer.ViT_B_32_Weights
:members:
"""
weights = ViT_B_32_Weights.verify(weights)
......@@ -538,11 +556,20 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_L_16_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.vision_transformer.ViT_L_16_Weights`, optional): The pretrained
weights to use. See :class:`~torchvision.models.vision_transformer.ViT_L_16_Weights`
below for more details and possible values. By default, no pre-trained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.vision_transformer.ViT_L_16_Weights
:members:
"""
weights = ViT_L_16_Weights.verify(weights)
......@@ -562,11 +589,20 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_L_32_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.vision_transformer.ViT_L_32_Weights`, optional): The pretrained
weights to use. See :class:`~torchvision.models.vision_transformer.ViT_L_32_Weights`
below for more details and possible values. By default, no pre-trained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.vision_transformer.ViT_L_32_Weights
:members:
"""
weights = ViT_L_32_Weights.verify(weights)
......@@ -585,11 +621,20 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_h_14 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_H_14_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.vision_transformer.ViT_H_14_Weights`, optional): The pretrained
weights to use. See :class:`~torchvision.models.vision_transformer.ViT_H_14_Weights`
below for more details and possible values. By default, no pre-trained weights are used.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
for more details about this class.
.. autoclass:: torchvision.models.vision_transformer.ViT_H_14_Weights
:members:
"""
weights = ViT_H_14_Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment