"docs/modules/vscode:/vscode.git/clone" did not exist on "dbf06b504b525c7f6680c5709b63df6413616d2e"
Unverified commit d425f007, authored by Nicolas Hug, committed by GitHub

Start doc revamp for detection models (#5876)

* Start doc revamp for detection models

* Minor cleanup

* Use list of tuples for metrics
parent cc53cd01

--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -22,6 +22,7 @@
 import os
 import textwrap
+from copy import copy
 from pathlib import Path
 
 import pytorch_sphinx_theme
@@ -330,7 +331,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
         # the `meta` dict contains another embedded `metrics` dict. To
         # simplify the table generation below, we create the
         # `meta_with_metrics` dict, where the metrics dict has been "flattened"
-        meta = field.meta
+        meta = copy(field.meta)
         metrics = meta.pop("metrics", {})
         meta_with_metrics = dict(meta, **metrics)
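
The reason for switching to copy(field.meta) above: the pop("metrics", ...) call on the next line mutates the dict in place, so without a shallow copy the "metrics" entry would be stripped from the weight's shared meta dict the first time the docs are built. A minimal standalone sketch of the difference, using a plain stand-in dict rather than torchvision metadata:

from copy import copy

shared_meta = {"num_params": 1_000_000, "metrics": {"acc@1": 80.0}}

# Without a copy: pop() removes "metrics" from the shared dict itself.
alias = shared_meta
alias.pop("metrics")
print("metrics" in shared_meta)   # False -- the original dict was mutated

# With a shallow copy: only the copy loses the "metrics" key.
shared_meta["metrics"] = {"acc@1": 80.0}
flattened = copy(shared_meta)
flattened.pop("metrics")
print("metrics" in shared_meta)   # True -- the original dict is intact
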
@@ -346,17 +347,18 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
         lines.append("")
 
 
-def generate_classification_table():
-
-    weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
+def generate_weights_table(module, table_name, metrics):
+    weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
     weights = [w for weight_enum in weight_enums for w in weight_enum]
 
-    column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
+    metrics_keys, metrics_names = zip(*metrics)
+    column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
+    column_names = [f"**{name}**" for name in column_names]  # Add bold
 
     content = [
         (
             f":class:`{w} <{type(w).__name__}>`",
-            w.meta["metrics"]["acc@1"],
-            w.meta["metrics"]["acc@5"],
+            *(w.meta["metrics"][metric] for metric in metrics_keys),
             f"{w.meta['num_params']/1e6:.1f}M",
             f"`link <{w.meta['recipe']}>`__",
         )
@@ -366,13 +368,14 @@ def generate_classification_table():
     generated_dir = Path("generated")
     generated_dir.mkdir(exist_ok=True)
-    with open(generated_dir / "classification_table.rst", "w+") as table_file:
+    with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
         table_file.write(".. table::\n")
-        table_file.write("    :widths: 100 10 10 20 10\n\n")
+        table_file.write(f"    :widths: 100 {'20 ' * len(metrics_names)} 20 10\n\n")
         table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
 
 
-generate_classification_table()
+generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
+generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")])
 
 
 def setup(app):
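
The refactored generate_weights_table() drives both tables from a list of (metric key, display name) tuples, so adding a new model family only requires another call with its own metrics. Below is a self-contained sketch of that parameterization; FakeDetector_Weights and its numbers are made up for illustration, and only the metric-tuple convention comes from the change above:

from enum import Enum


class FakeDetector_Weights(Enum):
    # Hypothetical weight entry; the metric and parameter values are placeholders.
    COCO_V1 = {"metrics": {"box_map": 40.0}, "num_params": 30_000_000, "recipe": "https://example.com/recipe"}

    @property
    def meta(self):
        return self.value


def build_table(weights, metrics):
    # metrics mirrors the argument added above: (key in meta["metrics"], column name) pairs.
    metrics_keys, metrics_names = zip(*metrics)
    column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
    rows = [
        (
            w.name,
            *(w.meta["metrics"][key] for key in metrics_keys),
            f"{w.meta['num_params'] / 1e6:.1f}M",
            w.meta["recipe"],
        )
        for w in weights
    ]
    return column_names, rows


print(build_table(FakeDetector_Weights, metrics=[("box_map", "Box MAP")]))
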

--- /dev/null
+++ b/docs/source/models/retinanet.rst
+RetinaNet
+=========
+
+.. currentmodule:: torchvision.models.detection
+
+The RetinaNet model is based on the `Focal Loss for Dense Object Detection
+<https://arxiv.org/abs/1708.02002>`__ paper.
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate a RetinaNet model, with or
+without pre-trained weights. All the model builders internally rely on the
+``torchvision.models.detection.retinanet.RetinaNet`` base class. Please refer to the `source code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_ for
+more details about this class.
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    retinanet_resnet50_fpn
+    retinanet_resnet50_fpn_v2
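
As a quick illustration of the builders listed above, a usage sketch (not part of the committed file; it assumes the DEFAULT alias that the multi-weight API exposes on weight enums):

import torch
from torchvision.models.detection import retinanet_resnet50_fpn, RetinaNet_ResNet50_FPN_Weights

# With pre-trained COCO weights
model = retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
model.eval()

# Without pre-trained weights (randomly initialized)
scratch_model = retinanet_resnet50_fpn(weights=None)

# Inference: a list of [C, H, W] tensors with values in the 0-1 range
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # one dict per image with "boxes", "labels" and "scores"
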

--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -59,4 +59,19 @@ Accuracies are reported on ImageNet
 Object Detection, Instance Segmentation and Person Keypoint Detection
 =====================================================================
 
-TODO: Something similar to classification models: list of models + table of weights
+.. currentmodule:: torchvision.models.detection
+
+The following detection models are available, with or without pre-trained
+weights:
+
+.. toctree::
+   :maxdepth: 1
+
+   models/retinanet
+
+Table of all available detection weights
+----------------------------------------
+
+Box MAPs are reported on COCO
+
+.. include:: generated/detection_table.rst

--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -727,7 +727,7 @@ def retinanet_resnet50_fpn(
     """
     Constructs a RetinaNet model with a ResNet-50-FPN backbone.
 
-    Reference: `"Focal Loss for Dense Object Detection" <https://arxiv.org/abs/1708.02002>`_.
+    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
 
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
@@ -763,13 +763,21 @@ def retinanet_resnet50_fpn(
         >>> predictions = model(x)
 
     Args:
-        weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
         num_classes (int, optional): number of output classes of the model (including the background)
-        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
         trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
             passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
+        :members:
     """
     weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
@@ -811,19 +819,27 @@ def retinanet_resnet50_fpn_v2(
     """
     Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
 
-    Reference: `"Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection"
+    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
     <https://arxiv.org/abs/1912.02424>`_.
 
     See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
 
     Args:
-        weights (RetinaNet_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
-        progress (bool): If True, displays a progress bar of the download to stderr
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
         num_classes (int, optional): number of output classes of the model (including the background)
-        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
         trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
+        :members:
     """
     weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
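
To make the documented arguments concrete, a hedged sketch of a typical fine-tuning setup; the three-class dataset and the explicit ImageNet backbone weights are illustrative choices, not values from this commit:

from torchvision.models import ResNet50_Weights
from torchvision.models.detection import retinanet_resnet50_fpn_v2

model = retinanet_resnet50_fpn_v2(
    weights=None,                                     # no COCO-pretrained detector weights
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,  # ImageNet-pretrained ResNet-50 backbone
    num_classes=3,                                    # output classes, including the background
    trainable_backbone_layers=3,                      # 0-5; None (the default) also means 3
)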