Unverified Commit 970ba355 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add mask rcnn to instance segmentation docs (#5949)



* Add mask rcnn to instance segmentation

* Add docs a use nicolas suggestion

* remov param

* remov docsstring, edit docs

* Add one section level
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 62740807
...@@ -348,14 +348,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines): ...@@ -348,14 +348,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines.append("") lines.append("")
def generate_weights_table(module, table_name, metrics, include_pattern=None, exclude_pattern=None): def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")] weight_enums = [getattr(module, name) for name in dir(module) if name.endswith("_Weights")]
weights = [w for weight_enum in weight_enums for w in weight_enum] weights = [w for weight_enum in weight_enums for w in weight_enum]
if include_pattern is not None: if include_patterns is not None:
weights = [w for w in weights if include_pattern in str(w)] weights = [w for w in weights if any(p in str(w) for p in include_patterns)]
if exclude_pattern is not None: if exclude_patterns is not None:
weights = [w for w in weights if exclude_pattern not in str(w)] weights = [w for w in weights if all(p not in str(w) for p in exclude_patterns)]
metrics_keys, metrics_names = zip(*metrics) metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"] column_names = ["Weight"] + list(metrics_names) + ["Params", "Recipe"]
...@@ -383,13 +383,19 @@ def generate_weights_table(module, table_name, metrics, include_pattern=None, ex ...@@ -383,13 +383,19 @@ def generate_weights_table(module, table_name, metrics, include_pattern=None, ex
generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")]) generate_weights_table(module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])
generate_weights_table( generate_weights_table(
module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_pattern="Keypoint" module=M.detection, table_name="detection", metrics=[("box_map", "Box MAP")], exclude_patterns=["Mask", "Keypoint"]
)
generate_weights_table(
module=M.detection,
table_name="instance_segmentation",
metrics=[("box_map", "Box MAP"), ("mask_map", "Mask MAP")],
include_patterns=["Mask"],
) )
generate_weights_table( generate_weights_table(
module=M.detection, module=M.detection,
table_name="detection_keypoint", table_name="detection_keypoint",
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")], metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
include_pattern="Keypoint", include_patterns=["Keypoint"],
) )
generate_weights_table( generate_weights_table(
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")] module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
......
Faster R-CNN Faster R-CNN
========== ============
.. currentmodule:: torchvision.models.detection .. currentmodule:: torchvision.models.detection
......
...@@ -89,12 +89,15 @@ All models are evaluated on COCO val2017: ...@@ -89,12 +89,15 @@ All models are evaluated on COCO val2017:
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
Object Detection Object Detection
================ ----------------
.. currentmodule:: torchvision.models.detection .. currentmodule:: torchvision.models.detection
The following detection models are available, with or without pre-trained The following object detection models are available, with or without pre-trained
weights: weights:
.. toctree:: .. toctree::
...@@ -102,20 +105,38 @@ weights: ...@@ -102,20 +105,38 @@ weights:
models/faster_rcnn models/faster_rcnn
models/fcos models/fcos
models/mask_rcnn
models/retinanet models/retinanet
models/ssdlite models/ssdlite
Table of all available detection weights Table of all available Object detection weights
---------------------------------------- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box MAPs are reported on COCO Box MAPs are reported on COCO
.. include:: generated/detection_table.rst .. include:: generated/detection_table.rst
Instance Segmentation
---------------------
.. currentmodule:: torchvision.models.detection
The following instance segmentation models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/mask_rcnn
Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Mask MAPs are reported on COCO
.. include:: generated/instance_segmentation_table.rst
Keypoint detection Keypoint Detection
================== ------------------
.. currentmodule:: torchvision.models.detection .. currentmodule:: torchvision.models.detection
...@@ -128,7 +149,7 @@ pre-trained weights: ...@@ -128,7 +149,7 @@ pre-trained weights:
models/keypoint_rcnn models/keypoint_rcnn
Table of all available Keypoint detection weights Table of all available Keypoint detection weights
------------------------------------------------- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Keypoint MAPs are reported on COCO: Box and Keypoint MAPs are reported on COCO:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment