Unverified commit f6178728, authored by HELSON, committed by GitHub
Browse files

[gemini] fix init bugs for modules (#2047)

* [gemini] fix init bugs for modules

* fix bugs
parent 81e0da7f
...@@ -96,10 +96,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): ...@@ -96,10 +96,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module. The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times? FIXME(fjr) The module may be passed to this function multiple times?
""" """
if hasattr(module, '_colo_visited'):
return
name_list = [] name_list = []
for name, param in _named_params_with_replica(module): for name, param in _named_params_with_replica(module):
if isinstance(param, ColoTensor): if isinstance(param, ColoTensor):
...@@ -130,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): ...@@ -130,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
colo_param.shared_param_modules.append(submodule) colo_param.shared_param_modules.append(submodule)
module.to(self._device) module.to(self._device)
ColoModulize(module)
def post_process_colo_init_ctx(model: torch.nn.Module, def post_process_colo_init_ctx(model: torch.nn.Module,
......
...@@ -24,6 +24,11 @@ from tests.components_to_test import run_fwd_bwd ...@@ -24,6 +24,11 @@ from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed from tests.test_tensor.common_utils import debug_print, set_seed
# this model is large enough to slice to chunks
TEST_MODELS = ['gpt2']
# these models are too small, all parameters in these models are compacted into one chunk
EXAMPLE_MODELS = ['hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
def check_param(model: ZeroDDP, torch_model: torch.nn.Module): def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
...@@ -40,10 +45,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): ...@@ -40,10 +45,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
# 'gpt2', 'bert',
TEST_MODELS = ['hanging_param_model', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', TEST_MODELS) @parameterize('model_name', TEST_MODELS)
def exam_model_step(placement_policy, model_name: str): def exam_model_step(placement_policy, model_name: str):
...@@ -61,8 +62,6 @@ def exam_model_step(placement_policy, model_name: str): ...@@ -61,8 +62,6 @@ def exam_model_step(placement_policy, model_name: str):
with ColoInitContext(device=init_dev): with ColoInitContext(device=init_dev):
model = model_builder() model = model_builder()
post_process_colo_init_ctx(model, device=init_dev)
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data) p.data.copy_(torch_p.data)
...@@ -102,8 +101,8 @@ def exam_model_step(placement_policy, model_name: str): ...@@ -102,8 +101,8 @@ def exam_model_step(placement_policy, model_name: str):
check_param(model, torch_model) check_param(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu']) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', TEST_MODELS) @parameterize('model_name', EXAMPLE_MODELS)
def exam_tiny_example(placement_policy, model_name: str): def exam_tiny_example(placement_policy, model_name: str):
set_seed(2008) set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
...@@ -119,8 +118,6 @@ def exam_tiny_example(placement_policy, model_name: str): ...@@ -119,8 +118,6 @@ def exam_tiny_example(placement_policy, model_name: str):
with ColoInitContext(device=init_dev): with ColoInitContext(device=init_dev):
model = model_builder() model = model_builder()
post_process_colo_init_ctx(model, device=init_dev)
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data) p.data.copy_(torch_p.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment