Unverified Commit f0aa191f authored by ver217's avatar ver217 Committed by GitHub
Browse files

[gemini] fix colo_init_context (#2683)

parent 5cd8cae0
......@@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec: Optional[Any] = None) -> ColoParameter:
if isinstance(param, ColoParameter):
if type(param) is ColoParameter:
return param
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
......@@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
"""
name_list = []
for name, param in _named_params_with_replica(module):
if isinstance(param, ColoTensor):
if type(param) is ColoParameter:
continue
split = name.rfind('.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment