Unverified Commit 5985504c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub

Doc revamp for optical flow models (#5895)

* Doc revamp for optical flow models

* Some more
parent 2ec0e847
@@ -347,6 +347,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
    metrics = meta.pop("_metrics")
    for dataset, dataset_metrics in metrics.items():
        for metric_name, metric_value in dataset_metrics.items():
            metric_name = metric_name.replace("_", "-")
            table.append((f"{metric_name} (on {dataset})", str(metric_value)))
    for k, v in meta.items():
...
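The one-line addition in this hunk swaps underscores for hyphens in metric names before they reach the generated weights table. A minimal sketch of the effect, using a hypothetical metrics dict shaped like the ``_metrics`` meta entries later in this diff:

```python
# Hypothetical metrics dict mirroring the shape of the "_metrics" meta entries.
metrics = {"Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 25.2801}}

table = []
for dataset, dataset_metrics in metrics.items():
    for metric_name, metric_value in dataset_metrics.items():
        metric_name = metric_name.replace("_", "-")  # the line this hunk adds
        table.append((f"{metric_name} (on {dataset})", str(metric_value)))

# table == [("per-image-epe (on Kitti-Train)", "5.0172"),
#           ("fl-all (on Kitti-Train)", "25.2801")]
```

Without the added line, the rendered rows would show the raw snake_case keys (``per_image_epe``, ``fl_all``).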
RAFT
====
.. currentmodule:: torchvision.models.optical_flow
The RAFT model is based on the `RAFT: Recurrent All-Pairs Field Transforms for
Optical Flow <https://arxiv.org/abs/2003.12039>`__ paper.
Model builders
--------------
The following model builders can be used to instantiate a RAFT model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.optical_flow.RAFT`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
raft_large
raft_small
@@ -376,6 +376,7 @@ Box MAPs are reported on COCO val2017:
.. include:: generated/detection_table.rst

Instance Segmentation
---------------------
@@ -481,6 +482,18 @@ Accuracies are reported on Kinetics-400 using single crops for clip length 16:
.. include:: generated/video_table.rst
Optical Flow
============
.. currentmodule:: torchvision.models.optical_flow
The following Optical Flow models are available, with or without pre-trained weights:
.. toctree::
:maxdepth: 1
models/raft
Using models from Hub
=====================
...
@@ -517,6 +517,19 @@ _COMMON_META = {
class Raft_Large_Weights(WeightsEnum):
    """The metrics reported here are as follows.

    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
    predicted flow is from its true value. This is averaged over all pixels
    of all images. ``per_image_epe`` is similar, but the average is different:
    the epe is first computed on each image independently, and then averaged
    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
    Kitti-specific metric, defined by the authors of the dataset and used for
    the Kitti leaderboard. It corresponds to the percentage of pixels whose epe
    is above both 3px and 5% of the flow's 2-norm.
    """
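The metrics in the docstring above can be made concrete with a small numeric sketch. A hedged NumPy illustration (the helper names and (H, W, 2) array shapes are assumptions for this example, not torchvision API); an outlier for Fl-all is a pixel whose epe exceeds both 3px and 5% of the ground-truth flow's 2-norm:

```python
import numpy as np

def epe_map(pred, gt):
    """Per-pixel end-point error: Euclidean distance between flow vectors."""
    return np.linalg.norm(pred - gt, axis=-1)

def fl_all(pred, gt):
    """Kitti Fl-all: percentage of outlier pixels, where an outlier has an
    epe above 3px and above 5% of the ground-truth flow's 2-norm."""
    err = epe_map(pred, gt)
    mag = np.linalg.norm(gt, axis=-1)
    outliers = (err > 3.0) & (err > 0.05 * mag)
    return 100.0 * outliers.mean()

# Toy 2x2 flow fields (values chosen by hand for illustration).
gt = np.zeros((2, 2, 2))
gt[..., 0] = 10.0        # ground truth: 10px horizontal flow everywhere
pred = gt.copy()
pred[0, 0, 0] = 18.0     # one pixel is off by 8px -> an outlier

print(epe_map(pred, gt).mean())  # epe averaged over all pixels -> 2.0
print(fl_all(pred, gt))          # 1 of 4 pixels is an outlier -> 25.0
```

``per_image_epe`` would apply ``epe_map(...).mean()`` per image first, then average those per-image means.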
    C_T_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT
        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
@@ -530,7 +543,9 @@ class Raft_Large_Weights(WeightsEnum):
                "Sintel-Train-Finalpass": {"epe": 2.7894},
                "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
            },
            "_docs": """These weights were ported from the original paper. They
                are trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )
@@ -546,7 +561,9 @@ class Raft_Large_Weights(WeightsEnum):
                "Sintel-Train-Finalpass": {"epe": 2.7161},
                "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
            },
            "_docs": """These weights were trained from scratch on
                :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )
@@ -563,8 +580,14 @@ class Raft_Large_Weights(WeightsEnum):
                "Sintel-Test-Finalpass": {"epe": 3.18},
            },
            "_docs": """
                These weights were ported from the original paper. They are
                trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
                Sintel. The Sintel fine-tuning step is a combination of
                :class:`~torchvision.datasets.Sintel`,
                :class:`~torchvision.datasets.KittiFlow`,
                :class:`~torchvision.datasets.HD1K`, and
                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
            """,
        },
    )
@@ -581,7 +604,14 @@ class Raft_Large_Weights(WeightsEnum):
                "Sintel-Test-Finalpass": {"epe": 3.067},
            },
            "_docs": """
                These weights were trained from scratch. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D` and then
                fine-tuned on Sintel. The Sintel fine-tuning step is a
                combination of :class:`~torchvision.datasets.Sintel`,
                :class:`~torchvision.datasets.KittiFlow`,
                :class:`~torchvision.datasets.HD1K`, and
                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
            """,
        },
    )
@@ -598,8 +628,12 @@ class Raft_Large_Weights(WeightsEnum):
                "Kitti-Test": {"fl_all": 5.10},
            },
            "_docs": """
                These weights were ported from the original paper. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`,
                fine-tuned on Sintel, and then fine-tuned on
                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
                step was described above.
            """,
        },
    )
@@ -615,7 +649,12 @@ class Raft_Large_Weights(WeightsEnum):
                "Kitti-Test": {"fl_all": 5.19},
            },
            "_docs": """
                These weights were trained from scratch. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`,
                fine-tuned on Sintel, and then fine-tuned on
                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
                step was described above.
            """,
        },
    )
@@ -624,6 +663,19 @@ class Raft_Large_Weights(WeightsEnum):
class Raft_Small_Weights(WeightsEnum):
    """The metrics reported here are as follows.

    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
    predicted flow is from its true value. This is averaged over all pixels
    of all images. ``per_image_epe`` is similar, but the average is different:
    the epe is first computed on each image independently, and then averaged
    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
    Kitti-specific metric, defined by the authors of the dataset and used for
    the Kitti leaderboard. It corresponds to the percentage of pixels whose epe
    is above both 3px and 5% of the flow's 2-norm.
    """
    C_T_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT
        url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
@@ -637,7 +689,9 @@ class Raft_Small_Weights(WeightsEnum):
                "Sintel-Train-Finalpass": {"epe": 3.2790},
                "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
            },
            "_docs": """These weights were ported from the original paper. They
                are trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )
    C_T_V2 = Weights(
@@ -652,7 +706,9 @@ class Raft_Small_Weights(WeightsEnum):
                "Sintel-Train-Finalpass": {"epe": 3.2831},
                "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
            },
            "_docs": """These weights were trained from scratch on
                :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )
@@ -750,13 +806,19 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
    Please see the example below for a tutorial on how to use this model.

    Args:
        weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.optical_flow.Raft_Large_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
        :members:
    """
    weights = Raft_Large_Weights.verify(weights)
@@ -794,19 +856,24 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
    """RAFT "small" model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.

    Please see the example below for a tutorial on how to use this model.

    Args:
        weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.optical_flow.Raft_Small_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
        :members:
    """
    weights = Raft_Small_Weights.verify(weights)
...