Unverified Commit c9625dbb authored by Baizhou Zhang, committed by GitHub

[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp
parent 2c787d7f
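
In short, this commit lets HybridParallelPlugin save and reload optimizer state as sharded checkpoint files through the booster API, mirroring the existing sharded model checkpoints. A minimal usage sketch, assuming a model, optimizer, and criterion are already built (plugin arguments and paths are illustrative, taken from the updated test below):

```python
# Sketch of the workflow this commit enables. Setup of `model`, `optimizer`,
# `criterion` and the training loop is omitted; plugin arguments are illustrative.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=2, precision='fp16', initial_scale=1)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# ... train for a few steps, then checkpoint ...
booster.save_model(model, "ckpt/model", shard=True, size_per_shard=32)
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=32)

# On resume: boost a fresh model/optimizer pair (new_model, new_optimizer),
# then load the sharded files back, as the test below does.
booster.load_model(new_model, "ckpt/model")
booster.load_optimizer(new_optimizer, "ckpt/optimizer")
```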
 import random
 from contextlib import nullcontext
 from functools import partial
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
 import numpy as np
 import torch

@@ -110,6 +110,36 @@ class HybridParallelModule(ModelWrapper):
         return module
+
+def get_param_info(optim: Optimizer):
+    # Get a backup of necessary information about the parameters for future use, including:
+    # 1. A complete param_groups list, with params stored as integer param_ids
+    # 2. A mapping from param address (obtained using id(param)) to integer param_id
+    # 3. A mapping from integer param_id back to param address
+    # 4. A mapping from param address (obtained using id(param)) to the original shape of the parameter before sharding
+    # When ZeRO is used, the params here are the fp16/bf16 model params rather than the fp32 master params held by the optimizer.
+    if optim is None:
+        return {}
+    param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+    start_index = 0
+    for group in optim.param_groups:
+        packed_group = {k: v for k, v in group.items() if k != 'params'}
+        packed_group['params'] = []
+        for param_id, param in enumerate(group['params'], start_index):
+            original_shape = param.shape if isinstance(param, torch.Tensor) else None
+            packed_group['params'].append(param_id)
+            param_info['param2id'][id(param)] = param_id
+            param_info['id2param'][param_id] = id(param)
+            param_info['param2shape'][id(param)] = original_shape
+        param_info['param_groups'].append(packed_group)
+        start_index += len(group['params'])
+    return param_info
+
 def init_pipeline_optimizer(optim: Optimizer, model: Module):
     params = set(model.parameters())
     new_param_groups = []
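
As an illustration of the structure get_param_info builds, a toy example (hypothetical, not part of the diff; assumes the function above is in scope):

```python
# Toy illustration of get_param_info's output; values are schematic.
import torch
from torch.optim import Adam

w = torch.randn(4, 2, requires_grad=True)
b = torch.randn(4, requires_grad=True)
info = get_param_info(Adam([w, b], lr=1e-3))

assert info['param_groups'][0]['params'] == [0, 1]        # params replaced by integer ids
assert info['param2id'][id(w)] == 0                       # param address -> id
assert info['id2param'][1] == id(b)                       # id -> param address
assert info['param2shape'][id(w)] == torch.Size([4, 2])   # shape before sharding
```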
@@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):

 class HybridParallelNaiveOptimizer(OptimizerWrapper):

-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(optim)

@@ -133,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 optim: Optimizer,
                 model: Module,
                 use_pipeline: bool,
+                param_info: OrderedDict,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,

@@ -142,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,

@@ -155,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 optimizer: Optimizer,
                 model: Module,
                 use_pipeline: bool,
+                param_info: OrderedDict,
                 initial_scale: int = 2**16,    # grad scaler config
                 min_scale: int = 1,
                 growth_factor: float = 2.,

@@ -172,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
                 tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                 forced_dtype: Optional[torch.dtype] = None):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,

@@ -356,6 +391,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,

@@ -366,25 +402,33 @@ class HybridParallelPlugin(PipelinePluginBase):
                 optimizer = HybridParallelAMPOptimizer(optimizer,
                                                        model,
                                                        use_pipeline=self.enable_pipeline_parallelism,
+                                                       param_info=param_info,
                                                        precision=self.precision,
                                                        max_norm=self.max_norm,
                                                        **self.amp_config)
+                self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
+                                                                 optimizer.master_to_working_map)
             else:
                 optimizer = HybridParallelNaiveOptimizer(optimizer,
                                                          model,
-                                                         use_pipeline=self.enable_pipeline_parallelism)
+                                                         use_pipeline=self.enable_pipeline_parallelism,
+                                                         param_info=param_info)
         else:
             assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
             assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
             optimizer = HybridParallelZeroOptimizer(optimizer,
                                                     model,
                                                     use_pipeline=self.enable_pipeline_parallelism,
+                                                    param_info=param_info,
                                                     dp_process_group=self.dp_group,
                                                     tp_process_group=self.tp_group,
                                                     verbose=True,
                                                     clip_grad_norm=self.max_norm,
                                                     **self.zero_config,
                                                     **self.amp_config)
+            self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
+                                                             optimizer._param_store.master_to_working_param)

         return model, optimizer, criterion, dataloader, lr_scheduler
     def execute_pipeline(self,

@@ -461,7 +505,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                                            **_kwargs)

     def get_checkpoint_io(self) -> CheckpointIO:
-        return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)
+        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return self.checkpoint_io

     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
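
Note how boost() now hands the checkpoint IO both directions of the master/working parameter mapping (plain attributes on the AMP optimizer, the param store for ZeRO). This is what lets optimizer states, which live on fp32 master params, be keyed by the ids that get_param_info assigned to the fp16/bf16 working params. A rough sketch of the idea on the checkpoint IO side (internals here are an assumption, not the actual HypridParallelCheckpointIO code):

```python
# Assumed sketch: translating a master param back to its working param's id.
class CheckpointIOSketch:

    def link_master_and_working_param(self, working_to_master_map, master_to_working_map):
        # Keep both directions so saving and loading can translate either way.
        self.working_to_master_map = working_to_master_map
        self.master_to_working_map = master_to_working_map

    def id_of_master_param(self, master_param, param_info):
        # param_info was built from the working (fp16/bf16) params before boosting,
        # so a master param must be mapped back before its id can be looked up.
        working_param = self.master_to_working_map[id(master_param)]
        return param_info['param2id'][id(working_param)]
```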
This diff is collapsed.
@@ -679,7 +679,7 @@ class ZeroDDP(ColoDDP):
                     gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                 gathered_param = gathered_param_buffer.pop(fp32_param)

-            block, block_size = sharder.append(prefix + name, gathered_param)
+            block, block_size = sharder.append_param(prefix + name, gathered_param)
             if block is not None:
                 yield block, block_size

@@ -690,7 +690,7 @@ class ZeroDDP(ColoDDP):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                block, block_size = sharder.append(prefix + name, buffer)
+                block, block_size = sharder.append_param(prefix + name, buffer)
                 if block is not None:
                     yield block, block_size

         # save extra states
@@ -698,7 +698,7 @@ class ZeroDDP(ColoDDP):
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
             extra_state = self.get_extra_state()
-            block, block_size = sharder.append(extra_state_key, extra_state)
+            block, block_size = sharder.append_param(extra_state_key, extra_state)
             if block is not None:
                 yield block, block_size
...
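
All three call sites keep the same consumption pattern: the sharder yields a full (block, block_size) pair whenever appending a tensor would overflow the current block. A hedged sketch of how a caller might persist the yielded blocks (the file naming and index layout are assumptions):

```python
# Hypothetical consumer of a sharded state_dict generator like the one above.
import os
import torch

def save_state_dict_shards(shard_iter, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    index = {}    # key -> shard file that holds it
    for i, (block, block_size) in enumerate(shard_iter):
        shard_file = f"model-{i:05d}.bin"
        torch.save(block, os.path.join(save_dir, shard_file))
        for key in block:
            index[key] = shard_file
    return index
```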
@@ -10,7 +10,7 @@ from torch.nn import Parameter
 from torch.optim import Optimizer

 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.tensor.d_tensor import is_distributed_tensor

@@ -691,49 +691,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
             Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
         """
-        current_block = {}
-        current_block_size = 0
+        sharder = StateDictSharder(max_shard_size)

         for param_id in self.id_to_real_params.keys():
             dist.barrier()
             state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
-            ret_block = None
-            ret_block_size = 0
-
-            # A state might contain more than one tensors.
-            # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
-            state_size = 0
-            isDTensor = False
-            for state_tensor in state.values():
-                # When state_tensor is not of Tensor class,
-                # e.g., a SGD optimizer with momentum set to 0 can have None as state
-                # The calculation of tensor size should be skipped to avoid error.
-                if not isinstance(state_tensor, torch.Tensor):
-                    continue
-                # If the states are stored as DTensors, mark isDTensor as true.
-                if is_distributed_tensor(state_tensor):
-                    isDTensor = True
-                state_size += calculate_tensor_size(state_tensor)
-
-            if not isDTensor:
-                if current_block_size + state_size > max_shard_size and current_block_size > 0:
-                    ret_block = current_block
-                    ret_block_size = current_block_size
-                    current_block = {}
-                    current_block_size = 0
-
-                current_block[param_id] = state
-                current_block_size += state_size
-
-            if ret_block != None:
-                yield ret_block, ret_block_size
+            block, block_size = sharder.append_optim_state(param_id, state)
+            if block is not None:
+                yield block, block_size

-        yield current_block, current_block_size
+        yield sharder.current_block, sharder.current_block_size


 class GeminiAdamOptimizer(ZeroOptimizer):
...
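
The bookkeeping deleted above is exactly what StateDictSharder centralizes. A condensed reconstruction of what its append_optim_state plausibly does, derived from the removed lines (class internals are an approximation, not the actual colossalai.checkpoint_io.utils code):

```python
# Approximate reconstruction of StateDictSharder from the logic it replaced.
import torch

from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.tensor.d_tensor import is_distributed_tensor


class StateDictSharderSketch:

    def __init__(self, max_shard_size: int):
        self.max_shard_size = max_shard_size
        self.current_block = {}
        self.current_block_size = 0

    def append_optim_state(self, param_id, state):
        # A state may hold several tensors, e.g. Adam's 'step', 'exp_avg', 'exp_avg_sq'.
        state_size = 0
        is_dtensor = False
        for tensor in state.values():
            # Non-tensor entries (e.g. None for SGD with momentum=0) carry no size.
            if not isinstance(tensor, torch.Tensor):
                continue
            if is_distributed_tensor(tensor):
                is_dtensor = True
            state_size += calculate_tensor_size(tensor)

        ret_block, ret_block_size = None, 0
        if not is_dtensor:
            # Flush the current block first if this state would overflow it.
            if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
                ret_block, ret_block_size = self.current_block, self.current_block_size
                self.current_block, self.current_block_size = {}, 0
            self.current_block[param_id] = state
            self.current_block_size += state_size
        return ret_block, ret_block_size
```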
@@ -10,6 +10,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import (
+    assert_close_loose,
     check_state_dict_equal,
     clear_cache_before_run,
     parameterize,

@@ -19,34 +20,34 @@ from colossalai.testing import (
 from tests.kit.model_zoo import model_zoo


+# TODO (Baizhou): Add test cases for shard=False
 @clear_cache_before_run()
 @parameterize('shard', [True])
 @parameterize('model_name', ['transformers_gpt'])
 @parameterize('size_per_shard', [32])
 @parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'precision': 'fp32',
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'precision': 'fp32',
+}, {
     'tp_size': 4,
     'pp_size': 1,
     'precision': 'fp32',
 }, {
     'tp_size': 2,
-    'pp_size': 1,
-    'precision': 'fp32',
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'precision': 'fp16',
+    'initial_scale': 1
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
@@ -61,46 +62,91 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
         loss = criterion(outputs)
         return loss

+    def _preprocess_data(data):
+        if booster.plugin.stage_manager is not None:
+            for k, v in data.items():
+                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+                    new_shape = [1] * v.dim()
+                    new_shape[0] = 4
+                    data[k] = v.to('cuda').repeat(*new_shape)
+            return iter([data])
+        else:
+            return {k: v.cuda() for k, v in data.items()}
+
     model = model_fn().cuda()
     optimizer = Adam(model.parameters(), lr=1e-3)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_model = model_fn().cuda()
-    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
-    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

     data = data_gen_fn()
     model.train()
     if booster.plugin.stage_manager is not None:
-        for k, v in data.items():
-            if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
-                new_shape = [1] * v.dim()
-                new_shape[0] = 4
-                data[k] = v.to('cuda').repeat(*new_shape)
-        data_iter = iter([data])
-        output = booster.execute_pipeline(data_iter,
-                                          model,
-                                          _criterion,
-                                          optimizer,
-                                          return_loss=True,
-                                          return_outputs=False)
+        booster.execute_pipeline(_preprocess_data(data),
+                                 model,
+                                 _criterion,
+                                 optimizer,
+                                 return_loss=True,
+                                 return_outputs=False)
     else:
-        data = {k: v.cuda() for k, v in data.items()}
-        output = model(**data)
+        output = model(**_preprocess_data(data))
         loss = criterion(output)
         optimizer.backward(loss)

     optimizer.step()

     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
-        # optimizer_ckpt_path = f"{tempdir}/optimizer"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
         booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
-        # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
         dist.barrier()
+        new_model = model_fn().cuda()
+        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
+
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
+        dist.barrier()
+
+    # Check whether the loaded model & optimizer work smoothly.
+    model.train()
+    new_model.train()
+    if booster.plugin.stage_manager is not None:
+        booster.execute_pipeline(_preprocess_data(data),
+                                 model,
+                                 _criterion,
+                                 optimizer,
+                                 return_loss=True,
+                                 return_outputs=False)
+        booster.execute_pipeline(_preprocess_data(data),
+                                 new_model,
+                                 _criterion,
+                                 new_optimizer,
+                                 return_loss=True,
+                                 return_outputs=False)
+    else:
+        old_model_loss = criterion(model(**_preprocess_data(data)))
+        optimizer.backward(old_model_loss)
+        new_model_loss = criterion(new_model(**_preprocess_data(data)))
+        new_optimizer.backward(new_model_loss)
+
+    optimizer.step()
+    new_optimizer.step()
+
+    # Check updated weights.
+    stage_manager = booster.plugin.stage_manager
+    if stage_manager is None or stage_manager.is_first_stage():
+        assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
+        assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data,
+                           new_model.unwrap().h[0].mlp.c_fc.weight.data,
+                           atol=5e-3,
+                           rtol=5e-3)
+
+    dist.barrier()
     Randomizer.reset_index()
     clear_layout_converter()
...
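
One commit bullet advertises "greedy loading of optimizer", which this test only exercises end-to-end through booster.load_optimizer. The idea, sketched below under stated assumptions (the file layout, names, and helper are hypothetical, not the actual checkpoint IO code), is that each rank scans the shard files in order and keeps only the states belonging to parameters it owns, so no rank ever materializes the full optimizer state dict:

```python
# Hypothetical sketch of greedy shard loading; layout and names are assumptions.
import os
import torch

def greedy_load_optim_states(ckpt_dir: str, needed_param_ids: set) -> dict:
    states = {}
    for fname in sorted(os.listdir(ckpt_dir)):
        if not fname.endswith('.bin'):
            continue
        shard = torch.load(os.path.join(ckpt_dir, fname), map_location='cpu')
        # Greedily pick out the states this rank needs as each shard is read.
        for param_id, state in shard.items():
            if param_id in needed_param_ids:
                states[param_id] = state
    return states
```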