Unverified Commit c9625dbb authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp
parent 2c787d7f
import random
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
import torch
......@@ -110,6 +110,36 @@ class HybridParallelModule(ModelWrapper):
return module
def get_param_info(optim: Optimizer):
    """Back up parameter bookkeeping info from an optimizer for later checkpoint use.

    The returned dict contains:
      1. 'param_groups': copies of each param group, with 'params' replaced by integer param_ids.
      2. 'param2id': mapping from param address (``id(param)``) to integer param_id.
      3. 'id2param': mapping from integer param_id back to param address.
      4. 'param2shape': mapping from param address to the parameter's shape before sharding.

    When ZeRO is used, the params seen here are the fp16/bf16 model params,
    not the fp32 master params held inside the optimizer.

    Returns an empty dict when ``optim`` is None.
    """
    if optim is None:
        return {}
    param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
    next_id = 0
    for group in optim.param_groups:
        # Copy every group hyper-parameter except the raw params list.
        packed = {key: value for key, value in group.items() if key != 'params'}
        packed['params'] = []
        for param in group['params']:
            # Non-tensor entries have no shape to record.
            shape = param.shape if isinstance(param, torch.Tensor) else None
            packed['params'].append(next_id)
            param_info['param2id'][id(param)] = next_id
            param_info['id2param'][next_id] = id(param)
            param_info['param2shape'][id(param)] = shape
            next_id += 1
        param_info['param_groups'].append(packed)
    return param_info
def init_pipeline_optimizer(optim: Optimizer, model: Module):
params = set(model.parameters())
new_param_groups = []
......@@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim)
......@@ -133,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
......@@ -142,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
......@@ -155,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
......@@ -172,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
......@@ -356,6 +391,7 @@ class HybridParallelPlugin(PipelinePluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
......@@ -366,25 +402,33 @@ class HybridParallelPlugin(PipelinePluginBase):
optimizer = HybridParallelAMPOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
optimizer.master_to_working_map)
else:
optimizer = HybridParallelNaiveOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism)
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
optimizer._param_store.master_to_working_param)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(self,
......@@ -461,7 +505,8 @@ class HybridParallelPlugin(PipelinePluginBase):
**_kwargs)
def get_checkpoint_io(self) -> CheckpointIO:
return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)
self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
This diff is collapsed.
......@@ -679,7 +679,7 @@ class ZeroDDP(ColoDDP):
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
block, block_size = sharder.append(prefix + name, gathered_param)
block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
......@@ -690,7 +690,7 @@ class ZeroDDP(ColoDDP):
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = sharder.append(prefix + name, buffer)
block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# save extra states
......@@ -698,7 +698,7 @@ class ZeroDDP(ColoDDP):
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = self.get_extra_state()
block, block_size = sharder.append(extra_state_key, extra_state)
block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
......
......@@ -10,7 +10,7 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
......@@ -691,49 +691,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
"""
current_block = {}
current_block_size = 0
sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
ret_block = None
ret_block_size = 0
# A state might contain more than one tensors.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
if ret_block != None:
yield ret_block, ret_block_size
block, block_size = sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size
yield current_block, current_block_size
yield sharder.current_block, sharder.current_block_size
class GeminiAdamOptimizer(ZeroOptimizer):
......
......@@ -10,6 +10,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_close_loose,
check_state_dict_equal,
clear_cache_before_run,
parameterize,
......@@ -19,34 +20,34 @@ from colossalai.testing import (
from tests.kit.model_zoo import model_zoo
# TODO (Baizhou): Add test cases for shard=False
@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'precision': 'fp32',
'pp_size': 2,
'num_microbatches': 4,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 2,
'pp_size': 1,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
......@@ -61,46 +62,91 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
loss = criterion(outputs)
return loss
def _preprocess_data(data):
    # When pipeline parallelism is enabled (a stage manager exists),
    # repeat every tensor along dim 0 to size 4 (micro-batches) and
    # wrap the batch in an iterator for execute_pipeline; otherwise
    # simply move every tensor to CUDA and return the dict directly.
    if booster.plugin.stage_manager is None:
        return {key: value.cuda() for key, value in data.items()}
    for key, value in data.items():
        if torch.is_tensor(value) or 'Tensor' in value.__class__.__name__:
            repeats = [1] * value.dim()
            repeats[0] = 4
            data[key] = value.to('cuda').repeat(*repeats)
    return iter([data])
model = model_fn().cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
new_model = model_fn().cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
data = data_gen_fn()
model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)
data_iter = iter([data])
output = booster.execute_pipeline(data_iter,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=False)
booster.execute_pipeline(_preprocess_data(data),
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=False)
else:
data = {k: v.cuda() for k, v in data.items()}
output = model(**data)
output = model(**_preprocess_data(data))
loss = criterion(output)
optimizer.backward(loss)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
# optimizer_ckpt_path = f"{tempdir}/optimizer"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
# booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()
new_model = model_fn().cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
dist.barrier()
# Check whether the loaded model & optimizer works smoothly.
model.train()
new_model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(_preprocess_data(data),
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=False)
booster.execute_pipeline(_preprocess_data(data),
new_model,
_criterion,
new_optimizer,
return_loss=True,
return_outputs=False)
else:
old_model_loss = criterion(model(**_preprocess_data(data)))
optimizer.backward(old_model_loss)
new_model_loss = criterion(new_model(**_preprocess_data(data)))
new_optimizer.backward(new_model_loss)
optimizer.step()
new_optimizer.step()
# Check updated weights.
stage_manager = booster.plugin.stage_manager
if stage_manager is None or stage_manager.is_first_stage():
assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data,
new_model.unwrap().h[0].mlp.c_fc.weight.data,
atol=5e-3,
rtol=5e-3)
dist.barrier()
Randomizer.reset_index()
clear_layout_converter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment