Unverified Commit 6724e791 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Check that the model can be inspected upon registration (#13743)

parent 03f48b3d
...@@ -347,6 +347,10 @@ class _ModelRegistry: ...@@ -347,6 +347,10 @@ class _ModelRegistry:
when importing the model and thus the related error when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
""" """
if not isinstance(model_arch, str):
msg = f"`model_arch` should be a string, not a {type(model_arch)}"
raise TypeError(msg)
if model_arch in self.models: if model_arch in self.models:
logger.warning( logger.warning(
"Model architecture %s is already registered, and will be " "Model architecture %s is already registered, and will be "
...@@ -360,8 +364,18 @@ class _ModelRegistry: ...@@ -360,8 +364,18 @@ class _ModelRegistry:
raise ValueError(msg) raise ValueError(msg)
model = _LazyRegisteredModel(*split_str) model = _LazyRegisteredModel(*split_str)
else:
try:
model.inspect_model_cls()
except Exception as exc:
msg = f"Unable to inspect model {model_cls}"
raise RuntimeError(msg) from exc
elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
model = _RegisteredModel.from_model_cls(model_cls) model = _RegisteredModel.from_model_cls(model_cls)
else:
msg = ("`model_cls` should be a string or PyTorch model class, "
f"not a {type(model_arch)}")
raise TypeError(msg)
self.models[model_arch] = model self.models[model_arch] = model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment