Unverified Commit 7752dc5e authored by Olatunji Ruwase, committed by GitHub

Fix layout bug in ZeRO Stage 1 checkpoint logic (#531)

* Fix layout bug in ZeRO Stage 1 checkpoint logic
Add an elastic checkpoint option for ZeRO Stage 1, defaulting to True

* Format fixes
parent 9941ce75
@@ -347,6 +347,9 @@ class DeepSpeedEngine(Module):
def zero_load_from_fp32_weights(self):
return self._config.zero_config.load_from_fp32_weights
def zero_elastic_checkpoint(self):
return self._config.zero_config.elastic_checkpoint
def fp16_enabled(self):
return self._config.fp16_enabled
@@ -669,6 +672,7 @@ class DeepSpeedEngine(Module):
allgather_size=self.zero_allgather_bucket_size(),
max_elements_per_comm=self.zero_reduce_bucket_size(),
dp_process_group=self.data_parallel_group,
elastic_checkpoint=self.zero_elastic_checkpoint(),
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
optimizer = FP16_DeepSpeedZeroOptimizer(
@@ -21,6 +21,7 @@ class DeepSpeedZeroConfig(object):
self.overlap_comm = None
self.load_from_fp32_weights = None
self.cpu_offload = None
self.elastic_checkpoint = None
if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -94,3 +95,8 @@ class DeepSpeedZeroConfig(object):
self.cpu_offload = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD,
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
self.elastic_checkpoint = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
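
For context, get_scalar_param falls back to the default when the key is absent, so existing configs are unaffected by the new option. A rough stand-in for that behavior (the real helper lives elsewhere in the codebase and is not part of this diff):

# Rough stand-in, for illustration only: return the configured value if present,
# otherwise the supplied default (here, ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT).
def get_scalar_param_sketch(config_dict, name, default):
    return config_dict.get(name, default)

print(get_scalar_param_sketch({'elastic_checkpoint': False}, 'elastic_checkpoint', True))  # False
print(get_scalar_param_sketch({}, 'elastic_checkpoint', True))                             # True
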
@@ -63,6 +63,9 @@ ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint'
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = True
ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
@@ -75,5 +78,6 @@ ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT,
ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT: ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT
}
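
Putting the new key together with the existing ones, a minimal user-facing config could look like the sketch below. Only the "zero_optimization" keys come from the constants above; the surrounding entries are a plausible example (mirroring the test configs further down), not something this PR prescribes.

# Sketch of a DeepSpeed config dict enabling ZeRO stage 1 and opting out of the
# elastic checkpoint layout (which defaults to True per the constants above).
# Everything outside "zero_optimization" is illustrative.
config_dict = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 1,
        "elastic_checkpoint": False
    }
}
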
@@ -123,7 +123,8 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
all_gather_partitions=True,
allgather_size=500000000,
clip_grad=0.0,
max_elements_per_comm=5e8):
max_elements_per_comm=5e8,
elastic_checkpoint=True):
if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
@@ -146,6 +147,9 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
self.max_elements_per_comm = max_elements_per_comm
logger.info("max_elements_per_comm={}".format(max_elements_per_comm))
self.elastic_checkpoint = elastic_checkpoint
logger.info(f'ZeRO Elastic Checkpoint = {elastic_checkpoint}')
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
@@ -757,18 +761,30 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return communication interval paddings for local rank and group
def _get_local_group_paddings(self, group_index):
local_rank = dist.get_rank(group=self.dp_process_group)
sub_partition_indices = [
local_rank + (comm_idx * self.partition_count)
for comm_idx in range(self.num_comm_intervals_per_group[group_index])
]
group_paddings = [
self.group_paddings[group_index][sub_idx]
for sub_idx in sub_partition_indices
]
return group_paddings
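
A quick worked example of the index arithmetic above, with made-up values: per-group paddings are stored interval-major with rank varying fastest, so with 2 ranks and 2 communication intervals rank 0 reads entries 0 and 2 while rank 1 reads entries 1 and 3.

# Illustrative only: which entries of a flat per-group padding list each rank owns,
# assuming the [rank0_comm0, rank1_comm0, rank0_comm1, rank1_comm1] layout used above.
partition_count = 2           # data-parallel world size (hypothetical)
num_comm_intervals = 2        # comm intervals for this group (hypothetical)
flat_paddings = [0, 3, 0, 7]  # hypothetical padding sizes

for local_rank in range(partition_count):
    indices = [local_rank + comm_idx * partition_count
               for comm_idx in range(num_comm_intervals)]
    print(local_rank, indices, [flat_paddings[i] for i in indices])
# 0 [0, 2] [0, 0]
# 1 [1, 3] [3, 7]
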
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains sub partitions.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
local_rank = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(groups_with_padding):
low_index = local_rank * len(group)
high_index = (local_rank + 1) * len(group)
group_paddings = self.group_paddings[i][low_index:high_index]
for group_index, group in enumerate(groups_with_padding):
group_paddings = self._get_local_group_paddings(group_index)
lean_sub_partitions = []
for j, sub_partition in enumerate(group):
lean_length = sub_partition.numel() - group_paddings[j]
for sub_partition, padding in zip(group, group_paddings):
lean_length = sub_partition.numel() - padding
lean_sub_partitions.append(sub_partition[:lean_length])
groups_without_padding.append(lean_sub_partitions)
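
The zip-based rewrite pairs each local sub-partition with exactly one padding value; a self-contained sketch of the trimming (tensors and padding counts are hypothetical):

import torch

# Hypothetical local sub-partitions: 10 real elements + 3 pad, and 8 real elements + 0 pad.
group = [torch.zeros(13), torch.zeros(8)]
group_paddings = [3, 0]

lean_sub_partitions = []
for sub_partition, padding in zip(group, group_paddings):
    lean_length = sub_partition.numel() - padding
    lean_sub_partitions.append(sub_partition[:lean_length])

print([t.numel() for t in lean_sub_partitions])  # [10, 8]
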
@@ -790,12 +806,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
local_rank = dist.get_rank(group=self.dp_process_group)
for group_idx, group in enumerate(self.optimizer.param_groups):
for group_index, group in enumerate(self.optimizer.param_groups):
param_paddings = self._get_local_group_paddings(group_index)
group_lean_state = []
low_index = local_rank * self.num_comm_intervals_per_group[group_idx]
high_index = (local_rank + 1) * self.num_comm_intervals_per_group[group_idx]
param_paddings = self.group_paddings[group_idx][low_index:high_index]
for param_idx, param in enumerate(group['params']):
lean_state = self._get_state_without_padding(self.optimizer.state[param],
param_paddings[param_idx])
@@ -805,7 +820,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
return optimizer_groups_state
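
For Adam-style optimizers the per-parameter state trimmed here is a dict of tensors (e.g. exp_avg, exp_avg_sq) plus scalars such as step. _get_state_without_padding itself sits outside this hunk; the stand-in below only illustrates the expected effect and is not the PR's implementation.

import torch

# Stand-in for illustration: tensor-valued entries lose their trailing padding,
# scalar entries (e.g. Adam's "step") pass through unchanged.
def strip_state_padding(state, padding):
    lean_state = {}
    for key, value in state.items():
        if torch.is_tensor(value):
            lean_state[key] = value[:value.numel() - padding]
        else:
            lean_state[key] = value
    return lean_state

adam_like_state = {'step': 10, 'exp_avg': torch.zeros(12), 'exp_avg_sq': torch.zeros(12)}
print(strip_state_padding(adam_like_state, padding=2)['exp_avg'].numel())  # 10
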
def state_dict(self):
def _rigid_state_dict(self):
"""
Returns a dict that can be loaded for continued training with same DP degree
"""
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
@@ -820,6 +838,19 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self.optimizer.state_dict()
state_dict[
'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
return state_dict
def _elastic_state_dict(self):
"""
Returns a dict that can be loaded for elastic training with different DP degree
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
@@ -833,13 +864,40 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
return state_dict
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
if self.elastic_checkpoint:
return self._elastic_state_dict()
return self._rigid_state_dict()
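
With the dispatch above, the checkpoint layout is decided purely by the config flag: the rigid dict embeds the padded base optimizer state and the local fp32 sub-partitions as-is (so it only round-trips at the same DP degree), while the elastic dict stores lean, padding-free per-sub-partition state that can be re-merged on load. A hedged caller-side sketch of saving one shard per rank (file naming and glue are hypothetical, not part of this PR):

import os
import torch
import torch.distributed as dist

# Hypothetical save path: every data-parallel rank writes its own optimizer.state_dict();
# with elastic_checkpoint=True those shards can later be re-partitioned across a
# different number of ranks. Assumes torch.distributed is already initialized.
def save_optimizer_shard(optimizer, ckpt_dir):
    rank = dist.get_rank()
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, f'optim_rank_{rank}.pt'))
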
# Extract the fp32 weights of the current rank from checkpoint by merging the
# sub partitions of communication intervals across ranks.
# Let sub_i_j = sub partition of rank i and comm interval j
# For 2 ranks and 2 comm intervals, checkpoints (minus padding) are as follows:
# rank 0 = [sub_0_0, sub_0_1]
# rank 1 = [sub_1_0, sub_1_1]
# Merge to get [sub_0_0, sub_1_0, sub_0_1, sub_1_1] => original un-padded flattened tensor.
def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights):
partition_id = dist.get_rank(group=self.dp_process_group)
num_partitions = len(all_partition_fp32_weights)
num_comm_intervals = len(all_partition_fp32_weights[0])
num_sub_partitions = num_partitions * num_comm_intervals
all_sub_partition_weights = [None] * num_sub_partitions
all_sub_partition_weights = []
for partition_weights in all_partition_fp32_weights:
for sub_partition_weights in partition_weights:
all_sub_partition_weights.append(sub_partition_weights)
for rank, partition_weights in enumerate(all_partition_fp32_weights):
for comm_idx, sub_partition_weights in enumerate(partition_weights):
#all_sub_partition_weights.append(sub_partition_weights)
sub_partition_idx = (comm_idx * num_partitions) + rank
all_sub_partition_weights[sub_partition_idx] = sub_partition_weights
flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_sub_partition_weights,
@@ -855,6 +913,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
dp_process_group=self.dp_process_group
)
partition_id = dist.get_rank(group=self.dp_process_group)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
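
The comment block above describes the heart of the fix: checkpointed sub-partitions must be merged interval-major with rank varying fastest ([sub_0_0, sub_1_0, sub_0_1, sub_1_1]), whereas the removed append loop produced rank-major order ([sub_0_0, sub_0_1, sub_1_0, sub_1_1]), which is the layout bug named in the commit title. A tiny standalone illustration of the corrected index formula:

# 2 ranks x 2 comm intervals; each inner list is one rank's checkpointed sub-partitions.
all_partition_fp32_weights = [
    ['sub_0_0', 'sub_0_1'],  # rank 0
    ['sub_1_0', 'sub_1_1'],  # rank 1
]
num_partitions = len(all_partition_fp32_weights)
num_comm_intervals = len(all_partition_fp32_weights[0])

merged = [None] * (num_partitions * num_comm_intervals)
for rank, partition_weights in enumerate(all_partition_fp32_weights):
    for comm_idx, sub_partition_weights in enumerate(partition_weights):
        # Same placement as the fixed code: interval-major, rank varies fastest.
        merged[(comm_idx * num_partitions) + rank] = sub_partition_weights

print(merged)  # ['sub_0_0', 'sub_1_0', 'sub_0_1', 'sub_1_1']
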
# Restore base optimizer fp32 weights from checkpoint by:
@@ -903,13 +962,18 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
# 3) Return state corresponding to local partition
def _retrieve_group_optimizer_states(self, all_partition_states):
merged_optimizer_states = {}
for partition_state in all_partition_states:
for sub_partition_state in partition_state:
num_partitions = len(all_partition_states)
num_comm_intervals = len(all_partition_states[0])
num_sub_partitions = num_partitions * num_comm_intervals
for rank, partition_state in enumerate(all_partition_states):
for comm_idx, sub_partition_state in enumerate(partition_state):
for key, value in sub_partition_state.items():
if not key in merged_optimizer_states.keys():
merged_optimizer_states[key] = [value]
else:
merged_optimizer_states[key].append(value)
merged_optimizer_states[key] = [None] * num_sub_partitions
sub_partition_idx = (comm_idx * num_partitions) + rank
merged_optimizer_states[key][sub_partition_idx] = value
group_optimizer_states = {}
for key, value in merged_optimizer_states.items():
@@ -950,10 +1014,23 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['base_optimizer_state'])
for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)
def _elastic_load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -981,3 +1058,46 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
if self.elastic_checkpoint:
self._elastic_load_state_dict(state_dict_list,
load_optimizer_states,
load_from_fp32_weights)
else:
self._rigid_load_state_dict(
state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states)
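
Note the asymmetry above: the elastic path consumes the full list of per-rank state dicts so it can re-merge and re-partition, while the rigid path picks out only the local rank's entry. A hypothetical caller-side sketch (file naming and loading glue are assumptions, not part of this PR):

import torch

# Hypothetical: gather every rank's saved shard and pass the whole list; the optimizer
# then follows the elastic or rigid path according to its elastic_checkpoint flag.
def load_optimizer_shards(optimizer, ckpt_dir, world_size):
    state_dict_list = [
        torch.load(f'{ckpt_dir}/optim_rank_{r}.pt', map_location='cpu')
        for r in range(world_size)
    ]
    optimizer.load_state_dict(state_dict_list,
                              load_optimizer_states=True,
                              load_from_fp32_weights=False)
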
def _dump_optimizer_state(self, message):
logger.info(f'{message}')
for i, group in enumerate(self.optimizer.param_groups):
for j, param in enumerate(group['params']):
for key, value in self.optimizer.state[param].items():
t_stats = [
value.min(),
value.max(),
(value.max() - value.min()),
value.mean()
]
stats = [float(t) for t in t_stats]
logger.info(
f'group/param/key/min/max/delta/mean = {i}, {j}, {key}: {stats}')
@@ -256,19 +256,16 @@ def test_checkpoint_fused_optimizer(tmpdir):
load_optimizer_states=False)
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
[
(1,
False,
'Adam'),
False),
(2,
False,
'Adam'),
False),
(2,
True,
'Adam'),
True),
])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
@@ -276,7 +273,7 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": adam_optimizer,
"type": 'Adam',
"params": {
"lr": 0.00015,
"betas": [0.8,
@@ -312,22 +309,16 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_opt
load_optimizer_states=True)
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
[
(1,
False,
"Adam"),
False),
(2,
False,
"Adam"),
False),
(2,
True,
'Adam'),
True),
])
def test_checkpoint_zero_no_optimizer(tmpdir,
zero_stage,
use_cpu_offload,
adam_optimizer):
def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
@@ -335,7 +326,7 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_op
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": adam_optimizer,
"type": 'Adam',
"params": {
"lr": 0.00015,
"betas": [0.8,
@@ -374,22 +365,18 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
load_optimizer_states=False)
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
[
(0,
False,
'Adam'),
False),
(1,
False,
'Adam'),
False),
(2,
False,
'Adam'),
False),
(2,
True,
'Adam'),
True),
])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
@@ -397,7 +384,7 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optim
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": adam_optimizer,
"type": 'Adam',
"params": {
"lr": 0.00015,
"betas": [0.8,
@@ -448,22 +435,18 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optim
load_lr_scheduler_states=True)
@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
@pytest.mark.parametrize('zero_stage, use_cpu_offload',
[
(0,
False,
'Adam'),
False),
(1,
False,
'Adam'),
False),
(2,
False,
'Adam'),
False),
(2,
True,
'Adam'),
True),
])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")
@@ -471,7 +454,7 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_op
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": adam_optimizer,
"type": 'Adam',
"params": {
"lr": 1e-5
}