"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "e80ebf6bc44df480dc4a6ea21694c6653d0936fb"
Unverified Commit ea13a201 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[polish] polish code for get_static_torch_model (#2405)

* [gemini] polish code

* [testing] remove code

* [gemini] make more robust
parent 551cafec
...@@ -334,10 +334,9 @@ class ZeroDDP(ColoDDP): ...@@ -334,10 +334,9 @@ class ZeroDDP(ColoDDP):
self.grads_device[tensor] = device self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True): def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
r""" """
Args: Args:
strict (bool): whether to return the whole model state strict (bool): whether to return the whole model state as the pytorch `Module.state_dict()`
as the original pytorch state_dict()
Returns: Returns:
dict: dict:
...@@ -349,25 +348,24 @@ class ZeroDDP(ColoDDP): ...@@ -349,25 +348,24 @@ class ZeroDDP(ColoDDP):
['bias', 'weight'] ['bias', 'weight']
""" """
if strict: if strict:
return get_static_torch_model(zero_ddp_model=self, device=get_current_device(), assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
only_rank_0=only_rank_0).state_dict(destination=destination, torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
prefix=prefix, return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination, return self._non_strict_state_dict(destination=destination,
prefix=prefix, prefix=prefix,
keep_vars=keep_vars, keep_vars=keep_vars,
only_rank_0=only_rank_0) only_rank_0=only_rank_0)
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
r"""Returns a dictionary containing a whole state of the module. """Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are Both parameters and persistent buffers (e.g. running averages) are included.
included. Keys are corresponding parameter and buffer names. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included. Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
tensors of the parameters are shared with other parameters which are shared with other parameters which have been included in the dictionary.
have been included in the dictionary. When you need to load the state dict, you should set the argument `strict` to False.
Returns: Returns:
dict: dict:
......
...@@ -47,17 +47,16 @@ def _get_shallow_copy_model(model: nn.Module): ...@@ -47,17 +47,16 @@ def _get_shallow_copy_model(model: nn.Module):
"""Get a shallow copy of the given model. Each submodule is different from the original submodule. """Get a shallow copy of the given model. Each submodule is different from the original submodule.
But the new submodule and the old submodule share all attributes. But the new submodule and the old submodule share all attributes.
""" """
name_to_module = dict() old_to_new = dict()
for name, module in _get_dfs_module_list(model): for name, module in _get_dfs_module_list(model):
new_module = copy(module) new_module = copy(module)
new_module._modules = OrderedDict() new_module._modules = OrderedDict()
for subname, submodule in module._modules.items(): for subname, submodule in module._modules.items():
if submodule is None: if submodule is None:
continue continue
full_name = name + ('.' if name else '') + subname setattr(new_module, subname, old_to_new[submodule])
setattr(new_module, subname, name_to_module[full_name]) old_to_new[module] = new_module
name_to_module[name] = new_module return old_to_new[model]
return name_to_module['']
def get_static_torch_model(zero_ddp_model, def get_static_torch_model(zero_ddp_model,
......
...@@ -31,8 +31,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): ...@@ -31,8 +31,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
for key, value in torch_dict.items(): for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it # key is 'module.model.PARAMETER', so we truncate it
key = key[7:] key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
......
...@@ -36,8 +36,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): ...@@ -36,8 +36,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
for key, value in torch_dict.items(): for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it # key is 'module.model.PARAMETER', so we truncate it
key = key[7:] key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
......
...@@ -45,8 +45,6 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): ...@@ -45,8 +45,6 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
torch_dict = torch_model.state_dict() torch_dict = torch_model.state_dict()
for key, value in torch_dict.items(): for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
...@@ -84,8 +82,6 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): ...@@ -84,8 +82,6 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
for key, value in torch_dict.items(): for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
......
...@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): ...@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
for key, value in torch_dict.items(): for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it # key is 'module.model.PARAMETER', so we truncate it
key = key[7:] key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment