"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "ceb47f1d7678d2f155144abd3e0eefb24684d35e"
Unverified Commit 62740807 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Document Keypoint RCNN separately (#5933)

* Document Keypoint RCNN separately

* Move Keypoint detection into its own section

* ufmt
parent a5c86ffa
...@@ -348,10 +348,15 @@ def inject_weight_metadata(app, what, name, obj, options, lines): ...@@ -348,10 +348,15 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines.append("") lines.append("")
def generate_weights_table(module, table_name, metrics): def generate_weights_table(module, table_name, metrics, include_pattern=None, exclude_pattern=None):
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")] weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
weights = [w for weight_enum in weight_enums for w in weight_enum] weights = [w for weight_enum in weight_enums for w in weight_enum]
if include_pattern is not None:
weights = [w for w in weights if include_pattern in str(w)]
if exclude_pattern is not None:
weights = [w for w in weights if exclude_pattern not in str(w)]
metrics_keys, metrics_names = zip(*metrics) metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"] column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
column_names = [f"**{name}**" for name in column_names] # Add bold column_names = [f"**{name}**" for name in column_names] # Add bold
...@@ -377,7 +382,15 @@ def generate_weights_table(module, table_name, metrics): ...@@ -377,7 +382,15 @@ def generate_weights_table(module, table_name, metrics):
generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")]) generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
generate_weights_table(module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")]) generate_weights_table(
module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_pattern="Keypoint"
)
generate_weights_table(
module=M.detection,
table_name="detection_keypoint",
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
include_pattern="Keypoint",
)
generate_weights_table( generate_weights_table(
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")] module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
) )
......
Keypoint R-CNN
==============
.. currentmodule:: torchvision.models.detection
The Keypoint R-CNN model is based on the `Mask R-CNN
<https://arxiv.org/abs/1703.06870>`__ paper.
Model builders
--------------
The following model builders can be used to instantiate a Keypoint R-CNN model,
with or without pre-trained weights. All the model builders internally rely on
the ``torchvision.models.detection.KeypointRCNN`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/keypoint_rcnn.py>`__
for more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
keypointrcnn_resnet50_fpn
...@@ -89,8 +89,8 @@ All models are evaluated on COCO val2017: ...@@ -89,8 +89,8 @@ All models are evaluated on COCO val2017:
Object Detection, Instance Segmentation and Person Keypoint Detection Object Detection
===================================================================== ================
.. currentmodule:: torchvision.models.detection .. currentmodule:: torchvision.models.detection
...@@ -114,6 +114,27 @@ Box MAPs are reported on COCO ...@@ -114,6 +114,27 @@ Box MAPs are reported on COCO
.. include:: generated/detection_table.rst .. include:: generated/detection_table.rst
Keypoint detection
==================
.. currentmodule:: torchvision.models.detection
The following keypoint detection models are available, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
models/keypoint_rcnn
Table of all available Keypoint detection weights
-------------------------------------------------
Box and Keypoint MAPs are reported on COCO:
.. include:: generated/detection_keypoint_table.rst
Video Classification Video Classification
==================== ====================
......
...@@ -366,7 +366,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -366,7 +366,7 @@ def keypointrcnn_resnet50_fpn(
""" """
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
Reference: `"Mask R-CNN" <https://arxiv.org/abs/1703.06870>`_. Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes. image, and should be in ``0-1`` range. Different images can have different sizes.
...@@ -410,14 +410,22 @@ def keypointrcnn_resnet50_fpn( ...@@ -410,14 +410,22 @@ def keypointrcnn_resnet50_fpn(
>>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11) >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
Args: Args:
weights (KeypointRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int, optional): number of output classes of the model (including the background) num_classes (int, optional): number of output classes of the model (including the background)
num_keypoints (int, optional): number of keypoints num_keypoints (int, optional): number of keypoints
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
pretrained weights for the backbone.
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3. passed (the default) this value is set to 3.
.. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
:members:
""" """
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone) weights_backbone = ResNet50_Weights.verify(weights_backbone)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment