Unverified Commit 6754f1b7 authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

fix module utils bug (#1066)

parent a0064407
...@@ -14,7 +14,7 @@ def register_colo_module(module_type: type, colo_module: ColoModule): ...@@ -14,7 +14,7 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
def is_colo_module(module: torch.nn.Module): def is_colo_module(module: torch.nn.Module):
global _COLOSSAL_MODULES global _COLOSSAL_MODULES
for module_type in _COLOSSAL_MODULES.keys(): for module_type in _COLOSSAL_MODULES.keys():
if isinstance(type(module), module_type): if isinstance(module, module_type):
return True return True
return False return False
...@@ -23,7 +23,7 @@ def get_colo_module(module: torch.nn.Module): ...@@ -23,7 +23,7 @@ def get_colo_module(module: torch.nn.Module):
global _COLOSSAL_MODULES global _COLOSSAL_MODULES
if is_colo_module(module): if is_colo_module(module):
for module_type, colo_module in _COLOSSAL_MODULES.items(): for module_type, colo_module in _COLOSSAL_MODULES.items():
if isinstance(type(module), module_type): if isinstance(module, module_type):
return colo_module return colo_module
else: else:
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment