"...text-generation-inference.git" did not exist on "895c5f15628df870f7a2ced7151dedb84231a996"
Unverified Commit f0148413 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add revamped docs for video classification models (#5894)

* Add revamped docs for video classification models

* EOL
parent 36c46357
......@@ -379,6 +379,7 @@ generate_weights_table(module=M.detection, table_name="detection", metrics=[("bo
generate_weights_table(
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
)
generate_weights_table(module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
def setup(app):
......
Video ResNet
============
.. currentmodule:: torchvision.models.video
The VideoResNet model is based on the `A Closer Look at Spatiotemporal
Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__ paper.
Model builders
--------------
The following model builders can be used to instantiate a VideoResNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.resnet.VideoResNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
r3d_18
mc3_18
r2plus1d_18
......@@ -101,3 +101,24 @@ Table of all available detection weights
Box MAPs are reported on COCO
.. include:: generated/detection_table.rst
Video Classification
====================
.. currentmodule:: torchvision.models.video
The following video classification models are available, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
models/video_resnet
Table of all available video classification weights
---------------------------------------------------
Accuracies are reported on Kinetics-400
.. include:: generated/video_table.rst
......@@ -365,15 +365,24 @@ class R2Plus1D_18_Weights(WeightsEnum):
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Construct 18 layer Resnet3D model as in
https://arxiv.org/abs/1711.11248
"""Construct 18 layer Resnet3D model.
Args:
weights (R3D_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
Returns:
VideoResNet: R3D-18 network
Args:
weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.R3D_18_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.R3D_18_Weights
:members:
"""
weights = R3D_18_Weights.verify(weights)
......@@ -390,15 +399,24 @@ def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, *
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Constructor for 18 layer Mixed Convolution network as in
https://arxiv.org/abs/1711.11248
"""Construct 18 layer Mixed Convolution network as in
Args:
weights (MC3_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
Returns:
VideoResNet: MC3 Network definition
Args:
weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.MC3_18_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.MC3_18_Weights
:members:
"""
weights = MC3_18_Weights.verify(weights)
......@@ -415,15 +433,24 @@ def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, *
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Constructor for the 18 layer deep R(2+1)D network as in
https://arxiv.org/abs/1711.11248
"""Construct 18 layer deep R(2+1)D network as in
Args:
weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
Returns:
VideoResNet: R(2+1)D-18 network
Args:
weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.R2Plus1D_18_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
:members:
"""
weights = R2Plus1D_18_Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment