Unverified Commit e6edcef4 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow custom docs for Weight enums and Weights fields (#5988)



* POC

* Update torchvision/models/resnet.py
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* Fix tests

* ufmt

* Remove useless docstring
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 64b1e279
......@@ -320,7 +320,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
"The model builder above accepts the following values as the ``weights`` parameter.",
f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``.",
]
if obj.__doc__ != "An enumeration.":
# We only show the custom enum doc if it was overriden. The default one from Python is "An enumeration"
lines.append("")
lines.append(obj.__doc__)
lines.append("")
for field in obj:
lines += [f"**{str(field)}**:", ""]
if field == obj.DEFAULT:
......@@ -335,10 +342,14 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics)
meta_with_metrics.pop("categories", None) # We don't want to document these, they can be too long
custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs
if custom_docs is not None:
lines += [custom_docs, ""]
for k, v in meta_with_metrics.items():
if k == "categories":
continue
elif k == "recipe":
if k == "recipe":
v = f"`link <{v}>`__"
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
......
......@@ -90,6 +90,7 @@ def test_schema_meta_validation(model_fn):
"num_params",
"recipe",
"unquantized",
"_docs",
}
# mandatory fields for each computer vision task
classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
......
......@@ -355,6 +355,9 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 76.130,
"acc@5": 92.862,
},
"_docs": """
These are standard weights using the basic recipe of the paper.
""",
},
)
IMAGENET1K_V2 = Weights(
......@@ -368,6 +371,10 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 80.858,
"acc@5": 95.434,
},
"_docs": """
These are improved weights, using TorchVision's `new recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
},
)
DEFAULT = IMAGENET1K_V2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment