Unverified Commit f8eec98f authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[tensor] fixed non-serializable colo parameter during model checkpointing (#1153)

parent ffa025e1
from .utils import InsertPostInitMethodToModuleSubClasses from .utils import InsertPostInitMethodToModuleSubClasses
import torch import torch
from colossalai.tensor import ColoTensor, ColoParameter from colossalai.tensor import ColoTensor, ColoParameter, distspec, TensorSpec
from colossalai.nn.parallel.layers import register_colo_module, \ from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding ColoLinear, ColoEmbedding
from copy import copy
from torch import nn from torch import nn
from typing import Iterator, Tuple, Union from typing import Iterator, Tuple, Union
from functools import partialmethod
# find named_params includes replica # find named_params includes replica
...@@ -34,6 +34,38 @@ def ColoModulize(module): ...@@ -34,6 +34,38 @@ def ColoModulize(module):
module._colo_visited = True module._colo_visited = True
def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
    """Build a state dict while distributed ColoParameters are temporarily replicated.

    Every ``ColoParameter`` that reports ``has_spec()`` gets its spec swapped for
    ``TensorSpec(distspec.replicate())`` before ``state_dict_func`` is called, and
    its original spec is put back afterwards. The restore step runs even when the
    assertion below or ``state_dict_func`` raises, so a failed checkpoint attempt
    does not leave the module with replicated specs.

    Args:
        destination: forwarded unchanged to ``state_dict_func``.
        prefix: forwarded unchanged to ``state_dict_func``.
        keep_vars: forwarded unchanged to ``state_dict_func``; must be falsy when
            the module holds at least one distributed ``ColoParameter``.
        state_dict_func: callable invoked as
            ``state_dict_func(self, destination, prefix, keep_vars)``; its return
            value is returned as-is.

    Returns:
        Whatever ``state_dict_func`` returns.

    Raises:
        AssertionError: if ``keep_vars`` is truthy and a distributed
            ``ColoParameter`` was found.
    """
    # original spec of every parameter that was switched to replicate, keyed by id(param)
    mapping = dict()
    try:
        # gather all params
        with torch.no_grad():
            for param in self.parameters():
                if isinstance(param, ColoParameter) and param.has_spec():
                    mapping[id(param)] = copy(param.spec)
                    param.set_spec(TensorSpec(distspec.replicate()))

        # TODO: fix when keep_vars = True
        # when keep_vars = False, the state_dict_func will call detach to create
        # new tensors, but when keep_vars = True, the recovery of spec will be reflected
        # in the `ret`, such that the final state dict will still contain process group,
        # raising exception as it is not serializable
        assert not (keep_vars and mapping), 'keep_vars cannot be True when there are distributed ColoParameters.'

        ret = state_dict_func(self, destination, prefix, keep_vars)
    finally:
        # recover the original specs, also on the error paths above
        with torch.no_grad():
            for param in self.parameters():
                param_id = id(param)
                if param_id in mapping:
                    param.set_spec(mapping[param_id])
    return ret
class ColoInitContext(InsertPostInitMethodToModuleSubClasses): class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')): def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
...@@ -52,6 +84,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): ...@@ -52,6 +84,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
register_colo_module(torch.nn.Linear, ColoLinear()) register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding()) register_colo_module(torch.nn.Embedding, ColoEmbedding())
def _pre_context_exec(self):
    """Route ``nn.Module.state_dict`` through ``colo_state_dict``.

    The function currently installed as ``nn.Module.state_dict`` is kept on
    ``self.state_dict_func`` and is also handed to ``colo_state_dict`` as its
    ``state_dict_func`` argument, so the wrapper can delegate to it.
    """
    # NOTE(review): this rebinds the attribute on nn.Module itself; where (or
    # whether) it is restored is not visible in this part of the file.
    original_state_dict = nn.Module.state_dict
    self.state_dict_func = original_state_dict
    nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=original_state_dict)
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
""" """
The function to call at the end of the constructor of each module. The function to call at the end of the constructor of each module.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment