Unverified Commit 111fe85d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Model doc revamp - first PR (#5821)

* First PR for model doc revamp

* Deactivating fail on warning, temporarily

* Remove commnet

* Minor changes

* Typos

* Added TODO in Makefile

* Keep old models.rst file intact, move new docs into new models_new.rst file
parent c888d6dc
......@@ -15,6 +15,7 @@ docs/build
docs/source/auto_examples/
docs/source/gen_modules/
docs/source/generated/
docs/source/models/generated/
# pytorch-sphinx-theme gets installed here
docs/src
......
......@@ -6,7 +6,9 @@ ifneq ($(EXAMPLES_PATTERN),)
endif
# You can set these variables from the command line.
SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
# TODO: Once the models doc revamp is done, set back the -W option to raise
# errors on warnings. See https://github.com/pytorch/vision/pull/5821#discussion_r850500693
SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS)
SPHINXBUILD = sphinx-build
SPHINXPROJ = torchvision
SOURCEDIR = source
......
......@@ -3,6 +3,7 @@ numpy
sphinx-copybutton>=0.3.1
sphinx-gallery>=0.9.0
sphinx==3.5.4
tabulate
# This pin is only needed for sphinx<4.0.2. See https://github.com/pytorch/vision/issues/5673 for details
Jinja2<3.1.*
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
......@@ -21,9 +21,13 @@
# sys.path.insert(0, os.path.abspath('.'))
import os
import textwrap
from pathlib import Path
import pytorch_sphinx_theme
import torchvision
import torchvision.models as M
from tabulate import tabulate
# -- General configuration ------------------------------------------------
......@@ -292,5 +296,57 @@ def inject_minigalleries(app, what, name, obj, options, lines):
lines.append("\n")
def inject_weight_metadata(app, what, name, obj, options, lines):
if obj.__name__.endswith("_Weights"):
lines[:] = ["The model builder above accepts the following values as the ``weights`` parameter:"]
lines.append("")
for field in obj:
lines += [f"**{str(field)}**:", ""]
table = []
for k, v in field.meta.items():
if k == "categories":
continue
elif k == "recipe":
v = f"`link <{v}>`__"
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
lines += [".. table::", ""]
lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("")
def generate_classification_table():
weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
weights = [w for weight_enum in weight_enums for w in weight_enum]
column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
content = [
(
f":class:`{w} <{type(w).__name__}>`",
w.meta["acc@1"],
w.meta["acc@5"],
f"{w.meta['num_params']/1e6:.1f}M",
f"`link <{w.meta['recipe']}>`__",
)
for w in weights
]
table = tabulate(content, headers=column_names, tablefmt="rst")
generated_dir = Path("generated")
generated_dir.mkdir(exist_ok=True)
with open(generated_dir / "classification_table.rst", "w+") as table_file:
table_file.write(".. table::\n")
table_file.write(" :widths: 100 10 10 20 10\n\n")
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
generate_classification_table()
def setup(app):
app.connect("autodoc-process-docstring", inject_minigalleries)
app.connect("autodoc-process-docstring", inject_weight_metadata)
......@@ -38,6 +38,7 @@ architectures, and common image transformations for computer vision.
ops
io
feature_extraction
models_new
.. toctree::
:maxdepth: 1
......
ResNet
======
.. currentmodule:: torchvision.models
The ResNet model is based on the `Deep Residual Learning for Image Recognition
<https://arxiv.org/abs/1512.03385>`_ paper.
Model builders
--------------
The following model builders can be used to instanciate a ResNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.resnet.ResNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
resnet18
resnet34
resnet50
resnet101
resnet152
VGG
===
.. currentmodule:: torchvision.models
The VGG model is based on the `Very Deep Convolutional Networks for Large-Scale
Image Recognition <https://arxiv.org/abs/1409.1556>`_ paper.
Model builders
--------------
The following model builders can be used to instanciate a VGG model, with or
without pre-trained weights. All the model buidlers internally rely on the
``torchvision.models.vgg.VGG`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn
.. _models_new:
Models and pre-trained weights - New
####################################
.. note::
These are the new models docs, documenting the new multi-weight API.
TODO: Once all is done, remove the "- New" part in the title above, and
rename this file as models.rst
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.
.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` to the model created using old PyTorch version.
On the contrary, loading entire saved models or serialized
``ScriptModules`` (seralized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
Classification
==============
.. currentmodule:: torchvision.models
The following classification models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/resnet
models/vgg
Table of all available classification weights
---------------------------------------------
Accuracies are reported on ImageNet
.. include:: generated/classification_table.rst
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
TODO: Something similar to classification models: list of models + table of weights
......@@ -556,12 +556,23 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
Args:
weights (ResNet18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet18_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.ResNet18_Weights
:members:
"""
weights = ResNet18_Weights.verify(weights)
......@@ -570,12 +581,23 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
Args:
weights (ResNet34_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet34_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.ResNet34_Weights
:members:
"""
weights = ResNet34_Weights.verify(weights)
......@@ -584,12 +606,23 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
Args:
weights (ResNet50_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet50_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.ResNet50_Weights
:members:
"""
weights = ResNet50_Weights.verify(weights)
......@@ -598,12 +631,23 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
Args:
weights (ResNet101_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet101_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.ResNet101_Weights
:members:
"""
weights = ResNet101_Weights.verify(weights)
......@@ -612,12 +656,23 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
Args:
weights (ResNet152_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.ResNet152_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.ResNet152_Weights
:members:
"""
weights = ResNet152_Weights.verify(weights)
......
......@@ -252,13 +252,23 @@ class VGG19_BN_Weights(WeightsEnum):
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG11_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG11_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG11_Weights
:members:
"""
weights = VGG11_Weights.verify(weights)
......@@ -267,13 +277,23 @@ def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **k
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG11_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG11_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG11_BN_Weights
:members:
"""
weights = VGG11_BN_Weights.verify(weights)
......@@ -282,13 +302,23 @@ def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG13_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG13_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG13_Weights
:members:
"""
weights = VGG13_Weights.verify(weights)
......@@ -297,13 +327,23 @@ def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **k
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG13_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG13_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG13_BN_Weights
:members:
"""
weights = VGG13_BN_Weights.verify(weights)
......@@ -312,13 +352,23 @@ def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG16_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG16_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG16_Weights
:members:
"""
weights = VGG16_Weights.verify(weights)
......@@ -327,13 +377,23 @@ def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **k
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG16_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG16_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG16_BN_Weights
:members:
"""
weights = VGG16_BN_Weights.verify(weights)
......@@ -342,13 +402,23 @@ def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration "E")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG19_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG19_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG19_Weights
:members:
"""
weights = VGG19_Weights.verify(weights)
......@@ -357,13 +427,23 @@ def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **k
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration 'E') with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32.
"""VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
Args:
weights (VGG19_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.VGG19_BN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
for more details about this class.
.. autoclass:: torchvision.models.VGG19_BN_Weights
:members:
"""
weights = VGG19_BN_Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment