Unverified Commit 807e01a4 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[zero] hotfix master param sync (#4618)

* [zero] add method to update master params

* [zero] update zero plugin

* [plugin] update low level zero plugin
parent aaeb520c
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import warnings import warnings
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch import torch
...@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import ( ...@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue, sharded_optimizer_loading_epilogue,
unwrap_optimizer, unwrap_optimizer,
) )
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO from .torch_ddp_plugin import TorchDDPCheckpointIO
...@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): ...@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    """Model wrapper for low-level ZeRO training.

    Casts the wrapped module (and, at forward time, its floating-point
    inputs) to the half-precision dtype implied by ``precision``, then
    moves the module to the current device.
    """

    def __init__(self, module: nn.Module, precision: str) -> None:
        super().__init__(module)
        # Any precision other than fp16/bf16 (i.e. fp32) leaves dtype as None.
        self.dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16}.get(precision)
        if self.dtype is not None:
            module = module.to(self.dtype)
        module = module.to(get_current_device())
        self.module = module
        # convert_fn casts floating-point forward inputs to self.dtype on the fly;
        # stays None for full-precision training so forward() is a passthrough.
        self.convert_fn = None
        if self.dtype is not None:
            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

    def forward(self, *args, **kwargs):
        convert = self.convert_fn
        if convert is not None:
            args = tree_map(convert, args)
            kwargs = tree_map(convert, kwargs)
        return super().forward(*args, **kwargs)

    def unwrap(self):
        # TODO(ver217): this is a workaround for loading model
        return self
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
...@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): ...@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
sharded_optimizer_loading_epilogue(optimizer) sharded_optimizer_loading_epilogue(optimizer)
def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
class LowLevelZeroModel(ModelWrapper): use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel)
def __init__(self, module: nn.Module, stage: int, precision: str) -> None: super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
super().__init__(module)
self.dtype = None def save_sharded_model(self,
if precision == 'fp16': model: nn.Module,
self.dtype = torch.float16 checkpoint_path: str,
elif precision == 'bf16': gather_dtensor: bool = True,
self.dtype = torch.bfloat16 prefix: Optional[str] = None,
module = zero_model_wrapper(module, zero_stage=stage) max_shard_size: int = 1024,
if self.dtype is not None: use_safetensors: bool = False):
module = module.to(self.dtype) assert isinstance(model, LowLevelZeroModel)
module = module.to(get_current_device()) super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
self.module = module use_safetensors)
self.convert_fn = None
if self.dtype is not None: def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
def forward(self, *args, **kwargs): model.update_master_params()
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args) def load_sharded_model(self,
kwargs = tree_map(self.convert_fn, kwargs) model: LowLevelZeroModel,
return super().forward(*args, **kwargs) checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
class LowLevelZeroPlugin(DPPluginBase): class LowLevelZeroPlugin(DPPluginBase):
...@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
super().__init__() super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
self.stage = stage self.stage = stage
self.precision = precision self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, self.zero_optim_kwargs = dict(
communication_dtype=communication_dtype, initial_scale=initial_scale,
overlap_communication=overlap_communication, growth_factor=growth_factor,
cpu_offload=cpu_offload) backoff_factor=backoff_factor,
self.optim_kwargs = dict(initial_scale=initial_scale, growth_interval=growth_interval,
growth_factor=growth_factor, hysteresis=hysteresis,
backoff_factor=backoff_factor, min_scale=min_scale,
growth_interval=growth_interval, max_scale=max_scale,
hysteresis=hysteresis, clip_grad_norm=max_norm,
min_scale=min_scale, reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
max_scale=max_scale, communication_dtype=communication_dtype,
max_norm=max_norm, overlap_communication=overlap_communication,
norm_type=norm_type) cpu_offload=cpu_offload,
partition_grad=(stage == 2),
)
self.verbose = verbose self.verbose = verbose
# set class name with stage, for better error message # set class name with stage, for better error message
...@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision) model = LowLevelZeroModel(model, self.precision)
if optimizer is not None and \ if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper): not isinstance(optimizer, OptimizerWrapper):
optimizer = zero_optim_wrapper(model.unwrap(), optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
optimizer, **self.zero_optim_kwargs,
optim_config=self.zero_optim_config, verbose=self.verbose)
**self.optim_kwargs, # inject update_master_params
verbose=self.verbose) model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler return model, optimizer, criterion, dataloader, lr_scheduler
......
from .model import ModelWrapper from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper from .optimizer import OptimizerWrapper
__all__ = ['OptimizerWrapper', 'ModelWrapper'] __all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
...@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module): ...@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.module(*args, **kwargs) return self.module(*args, **kwargs)
class AMPModelMixin:
    """Mixin declaring the interface that AMP-capable model wrappers implement."""

    def update_master_params(self):
        """Synchronize the master (full-precision) parameters from the
        working parameters.

        The base implementation is intentionally a no-op; wrappers that
        maintain separate master copies override (or inject) this method.
        """
...@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple ...@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.optim import Optimizer from torch.optim import Optimizer
...@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block_size += current_block_size ret_block_size += current_block_size
yield ret_block, ret_block_size yield ret_block, ret_block_size
def update_master_params(self, model: nn.Module) -> None:
    """Copy the current working params of ``model`` into their master copies.

    Each working param is flattened, zero-padded to the partitioned length,
    and the shard belonging to this rank is copied into the corresponding
    master param held by the optimizer's param store.

    Args:
        model (nn.Module): The model whose working params are the source.
    """
    working_to_master = self._param_store.working_to_master_param
    for param in model.parameters():
        master = working_to_master.get(id(param))
        if master is None:
            # Param is not managed by this optimizer (e.g. frozen).
            continue
        flat = param.data.view(-1)
        pad = self._param_store.get_param_padding_size(param)
        if pad > 0:
            flat = torch.nn.functional.pad(flat, [0, pad])
        # Master params are sharded evenly across ranks; keep only our shard.
        master.copy_(flat.chunk(self._world_size)[self._local_rank])
...@@ -14,6 +14,7 @@ from colossalai.testing import ( ...@@ -14,6 +14,7 @@ from colossalai.testing import (
rerun_if_address_is_in_use, rerun_if_address_is_in_use,
spawn, spawn,
) )
from colossalai.zero import LowLevelZeroOptimizer
# stage 1 and 2 process the optimizer/mode the same way # stage 1 and 2 process the optimizer/mode the same way
...@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): ...@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
booster.load_model(new_model, model_ckpt_path) booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
assert p_id in working_param_id_set
working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
padding = new_optimizer._param_store.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(working_shard,
master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
booster.load_optimizer(new_optimizer, optimizer_ckpt_path) booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment