Unverified Commit e6edcef4 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow custom docs for Weight enums and Weights fields (#5988)



* POC

* Update torchvision/models/resnet.py
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* Fix tests

* ufmt

* Remove useless docstring
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 64b1e279
...@@ -320,7 +320,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines): ...@@ -320,7 +320,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
"The model builder above accepts the following values as the ``weights`` parameter.", "The model builder above accepts the following values as the ``weights`` parameter.",
f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``.", f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``.",
] ]
if obj.__doc__ != "An enumeration.":
# We only show the custom enum doc if it was overriden. The default one from Python is "An enumeration"
lines.append("")
lines.append(obj.__doc__)
lines.append("") lines.append("")
for field in obj: for field in obj:
lines += [f"**{str(field)}**:", ""] lines += [f"**{str(field)}**:", ""]
if field == obj.DEFAULT: if field == obj.DEFAULT:
...@@ -335,10 +342,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines): ...@@ -335,10 +342,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
metrics = meta.pop("metrics", {}) metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics) meta_with_metrics = dict(meta, **metrics)
meta_with_metrics.pop("categories", None) # We don't want to document these, they can be too long
custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs
if custom_docs is not None:
lines += [custom_docs, ""]
for k, v in meta_with_metrics.items(): for k, v in meta_with_metrics.items():
if k == "categories": if k == "recipe":
continue
elif k == "recipe":
v = f"`link <{v}>`__" v = f"`link <{v}>`__"
table.append((str(k), str(v))) table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst") table = tabulate(table, tablefmt="rst")
......
...@@ -90,6 +90,7 @@ def test_schema_meta_validation(model_fn): ...@@ -90,6 +90,7 @@ def test_schema_meta_validation(model_fn):
"num_params", "num_params",
"recipe", "recipe",
"unquantized", "unquantized",
"_docs",
} }
# mandatory fields for each computer vision task # mandatory fields for each computer vision task
classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")} classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
......
...@@ -355,6 +355,9 @@ class ResNet50_Weights(WeightsEnum): ...@@ -355,6 +355,9 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 76.130, "acc@1": 76.130,
"acc@5": 92.862, "acc@5": 92.862,
}, },
"_docs": """
These are standard weights using the basic recipe of the paper.
""",
}, },
) )
IMAGENET1K_V2 = Weights( IMAGENET1K_V2 = Weights(
...@@ -368,6 +371,10 @@ class ResNet50_Weights(WeightsEnum): ...@@ -368,6 +371,10 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 80.858, "acc@1": 80.858,
"acc@5": 95.434, "acc@5": 95.434,
}, },
"_docs": """
These are improved weights, using TorchVision's `new recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
}, },
) )
DEFAULT = IMAGENET1K_V2 DEFAULT = IMAGENET1K_V2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment