Unverified Commit fad0e273 authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[PyTorch] Fix for type checking failure on custom callables (#790)



fix type checking in checkpointing to assume that there must be TE modules in custom callables
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
parent 08e5e4b1
...@@ -521,9 +521,12 @@ def has_te_modules(network): ...@@ -521,9 +521,12 @@ def has_te_modules(network):
for module in network.modules(): for module in network.modules():
if any(isinstance(module, te_class) for te_class in te_classes_list): if any(isinstance(module, te_class) for te_class in te_classes_list):
return True return True
return False return False
# Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module,
# so just assume that it has TE modules just to be safe.
return True
def checkpoint( def checkpoint(
function: Callable, function: Callable,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment