Unverified Commit 9880fd2c authored by eric8607242's avatar eric8607242 Committed by GitHub
Browse files

Fix state_dict key missing issue of the ZeroDDP (#2363)

* Fix state_dict output for ZeroDDP duplicated parameters

* Rewrite state_dict based on get_static_torch_model

* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
parent ce08661e
...@@ -18,6 +18,7 @@ from colossalai.utils import get_current_device ...@@ -18,6 +18,7 @@ from colossalai.utils import get_current_device
from colossalai.zero.utils.gemini_hook import GeminiZeROHook from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from .reducer import Reducer from .reducer import Reducer
from .utils import get_static_torch_model
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
...@@ -251,6 +252,7 @@ class ZeroDDP(ColoDDP): ...@@ -251,6 +252,7 @@ class ZeroDDP(ColoDDP):
pin_memory=pin_memory) pin_memory=pin_memory)
self.fp32_params.append(fp32_p) self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups() self.chunk_manager.close_all_groups()
self._cast_buffers() self._cast_buffers()
...@@ -331,12 +333,11 @@ class ZeroDDP(ColoDDP): ...@@ -331,12 +333,11 @@ class ZeroDDP(ColoDDP):
for tensor in chunk.get_tensors(): for tensor in chunk.get_tensors():
self.grads_device[tensor] = device self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
r"""Returns a dictionary containing a whole state of the module. r"""
Args:
Both parameters and persistent buffers (e.g. running averages) are strict (bool): whether to return the whole model state
included. Keys are corresponding parameter and buffer names. as the original pytorch state_dict()
Parameters and buffers set to ``None`` are not included.
Returns: Returns:
dict: dict:
...@@ -346,7 +347,31 @@ class ZeroDDP(ColoDDP): ...@@ -346,7 +347,31 @@ class ZeroDDP(ColoDDP):
>>> module.state_dict().keys() >>> module.state_dict().keys()
['bias', 'weight'] ['bias', 'weight']
"""
if strict:
return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
only_rank_0=only_rank_0).state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars,
only_rank_0=only_rank_0)
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non-strict state dict omits a parameter if its
tensor is shared with another parameter that has already
been included in the dictionary.
Returns:
dict:
a dictionary containing a whole state of the module
""" """
if destination is None: if destination is None:
destination = OrderedDict() destination = OrderedDict()
......
...@@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module): ...@@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module):
return name_to_module[''] return name_to_module['']
def get_static_torch_model(gemini_ddp_model, def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"), device=torch.device("cpu"),
dtype=torch.float32, dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module: only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given GeminiDDP module. """Get a static torch.nn.Module model from the given ZeroDDP module.
You should notice that the original GeminiDDP model is not modified. You should notice that the original ZeroDDP model is not modified.
Thus, you can use the original model in further training. Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors. But you should not use the returned torch model to train, this can cause unexpected errors.
Args: Args:
gemini_ddp_model (GeminiDDP): a gemini ddp model zero_ddp_model (ZeroDDP): a zero ddp model
device (torch.device): the device of the final torch model device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank0 has the converted torch model only_rank_0 (bool): if True, only rank0 has the converted torch model
...@@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model, ...@@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model,
Returns: Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
""" """
from colossalai.nn.parallel import GeminiDDP from colossalai.nn.parallel import ZeroDDP
assert isinstance(gemini_ddp_model, GeminiDDP) assert isinstance(zero_ddp_model, ZeroDDP)
state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
colo_model = gemini_ddp_model.module colo_model = zero_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model) torch_model = _get_shallow_copy_model(colo_model)
if not only_rank_0 or dist.get_rank() == 0: if not only_rank_0 or dist.get_rank() == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment