Unverified Commit 828b9e5e authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] fix zero optim save/load state dict (#1381)

parent b6fd165f
...@@ -104,8 +104,8 @@ class ProcessGroup: ...@@ -104,8 +104,8 @@ class ProcessGroup:
def set_cpu_groups(self):
    """Lazily initialize the Gloo (CPU) process groups for the TP and DP rank lists.

    Returns immediately if the CPU groups were already created; afterwards
    ``self._has_cpu_groups`` is set so subsequent calls are no-ops.
    """
    if self.has_cpu_groups:
        return
    # self.logger.info(
    #     f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
    # Registering the rank lists with the 'gloo' backend creates (or fetches)
    # the CPU-side process groups; the returned handles are cached in the dict.
    PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
    PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
    self._has_cpu_groups = True
...@@ -8,6 +8,9 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler ...@@ -8,6 +8,9 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable from colossalai.utils import get_current_device, disposable
from collections import defaultdict, abc as container_abcs
from copy import deepcopy
from itertools import chain
class OptimState(Enum): class OptimState(Enum):
...@@ -191,22 +194,105 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -191,22 +194,105 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.chunk_manager.add_extern_static_tensor(val) self.chunk_manager.add_extern_static_tensor(val)
def state_dict(self): def state_dict(self):
r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups where each
parameter group is a dict
"""
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
return
optim_state_dict = super().state_dict() optim_state_dict = super().state_dict()
scaler_state_dict = self.grad_scaler.state_dict() scaler_state_dict = self.grad_scaler.state_dict()
optim_state_dict['scaler'] = scaler_state_dict optim_state_dict['scaler'] = scaler_state_dict
if not self.chunk_manager.enable_distributed_storage:
return optim_state_dict return optim_state_dict
local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
if not self.chunk_manager.process_group.has_cpu_groups:
self.chunk_manager.process_group.set_cpu_groups()
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
dist.gather_object(local_state,
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
dst=dst_rank,
group=self.chunk_manager.process_group.cpu_dp_process_group())
if not is_rank_0:
return
for state in output:
optim_state_dict['state'].update(state)
return optim_state_dict
def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    if 'scaler' not in state_dict:
        self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
    else:
        self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))

    # Validate the state_dict
    groups = self.param_groups
    saved_groups = deepcopy(state_dict['param_groups'])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Update the state: map the saved integer param ids onto this
    # optimizer's live parameter objects, position by position.
    id_map = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
                                       chain.from_iterable((g['params'] for g in groups)))
    }

    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility). Casting targets the fp32 master copy of each
    # fp16 param; params with zero storage (not stored on this rank under
    # distributed storage) keep no local state.
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = self.fp16_param_to_fp32_param[id_map[k]]
            if param.storage().size() > 0:
                state[param] = cast(param, deepcopy(v))
        else:
            state[k] = deepcopy(v)

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({'state': state, 'param_groups': param_groups})
def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
    """Return a shallow copy of ``state`` with every tensor value moved to CPU.

    Non-tensor values are carried over unchanged.
    """
    cpu_state = {}
    for key, value in state.items():
        cpu_state[key] = value.cpu() if isinstance(value, torch.Tensor) else value
    return cpu_state
import pytest import pytest
import colossalai import colossalai
import torch import torch
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.core import global_context as gpc from colossalai.gemini import ChunkManager
from functools import partial from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel.data_parallel import ZeroDDP from colossalai.nn.parallel import ZeroDDP
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.testing import parameterize
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
def check_state(s1, s2):
    """Assert that two optimizer-state dicts hold pairwise-equal values.

    Tensor values are compared exactly after moving the first onto the
    second's device; other values are compared with ``==``.
    """
    for left, right in zip(s1.values(), s2.values()):
        if isinstance(left, torch.Tensor):
            left = left.to(right.device)
            # On mismatch, report the summed absolute difference for debugging.
            assert torch.equal(left, right), f'{torch.sum((left - right).abs())}'
        else:
            assert left == right
def check_load_state_dict(optim, torch_optim):
    """Verify the ZeroOptimizer's per-param state matches the torch optimizer's.

    Walks the param groups of both optimizers in lockstep and compares the
    state entries of each parameter pair.
    """
    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
        for p, torch_p in zip(group['params'], torch_group['params']):
            state = optim.optim.state[p]
            torch_state = torch_optim.state[torch_p]
            if p.storage().size() == 0:
                # Param has no storage on this rank (distributed storage):
                # no optimizer state should have been loaded for it.
                assert len(state) == 0
            check_state(state, torch_state)
def check_state_dict(state_dict, torch_state_dict):
    """Verify a gathered ZeroOptimizer state dict matches the torch one key-by-key."""
    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
        assert k1 == k2
        check_state(s1, s2)
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
    """Round-trip a torch Adam state dict through ZeroOptimizer and verify it.

    Builds a gpt2 test model under ZeroDDP and a plain torch twin, takes one
    torch optimizer step to populate state, loads that state into the
    ZeroOptimizer, then checks both the loaded per-param state and the
    re-exported state dict (on DP rank 0) against the torch reference.
    """
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(42)
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    pg = ProcessGroup()
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)

    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=1)

    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    # Populate optimizer state with one step on synthetic gradients.
    for p in torch_model.parameters():
        p.grad = torch.rand_like(p)
    torch_optim.step()

    torch_state_dict = torch_optim.state_dict()
    optim.load_state_dict(torch_state_dict)
    check_load_state_dict(optim, torch_optim)

    state_dict = optim.state_dict()
    # Only rank 0 returns the full (gathered) state dict.
    if pg.rank() == 0:
        check_state_dict(state_dict, torch_state_dict)
def run_dist(rank, world_size, port):
    """Per-process entry point: bring up the distributed env, then run the test body."""
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_zero_optim_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_optim_state_dict(world_size):
    """Spawn ``world_size`` processes and run the zero-optim state dict check in each."""
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_optim_state_dict(2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment