"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "74f01405337f861c38c55d9a2e528600583ae30b"
Unverified commit 7752dc5e, authored by Olatunji Ruwase, committed by GitHub

Fix layout bug in ZeRO Stage 1 checkpoint logic (#531)

* Fix layout bug in ZeRO Stage 1 checkpoint logic
Add elastic checkpoint option for ZeRO stage 1, default to True

* Format fixes
parent 9941ce75
@@ -347,6 +347,9 @@ class DeepSpeedEngine(Module):
     def zero_load_from_fp32_weights(self):
         return self._config.zero_config.load_from_fp32_weights

+    def zero_elastic_checkpoint(self):
+        return self._config.zero_config.elastic_checkpoint
+
     def fp16_enabled(self):
         return self._config.fp16_enabled

@@ -669,6 +672,7 @@ class DeepSpeedEngine(Module):
                 allgather_size=self.zero_allgather_bucket_size(),
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
+                elastic_checkpoint=self.zero_elastic_checkpoint(),
                 mpu=self.mpu)
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
@@ -21,6 +21,7 @@ class DeepSpeedZeroConfig(object):
         self.overlap_comm = None
         self.load_from_fp32_weights = None
         self.cpu_offload = None
+        self.elastic_checkpoint = None

         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]

@@ -94,3 +95,8 @@ class DeepSpeedZeroConfig(object):
         self.cpu_offload = get_scalar_param(zero_config_dict,
                                             ZERO_OPTIMIZATION_CPU_OFFLOAD,
                                             ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
+
+        self.elastic_checkpoint = get_scalar_param(
+            zero_config_dict,
+            ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT,
+            ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT)
@@ -63,6 +63,9 @@ ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
 ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
 ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False

+ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT = 'elastic_checkpoint'
+ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT = True
+
 ZERO_OPTIMIZATION_DEFAULT = {
     ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:

@@ -75,5 +78,6 @@ ZERO_OPTIMIZATION_DEFAULT = {
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
     ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
-    ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
+    ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT,
+    ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT: ZERO_OPTIMIZATION_ELASTIC_CHECKPOINT_DEFAULT
 }
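The new knob is read from the "zero_optimization" section of the DeepSpeed config and defaults to True. A minimal sketch of a config dict that sets it explicitly; only the "zero_optimization"/"elastic_checkpoint" keys come from this commit, the surrounding keys are illustrative assumptions:

    # Minimal sketch of a DeepSpeed config dict (illustrative; only
    # 'zero_optimization'/'elastic_checkpoint' are introduced by this commit).
    ds_config = {
        "train_batch_size": 2,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 1,
            # True (default): padding-stripped checkpoint that can be reloaded
            # at a different data-parallel degree.
            # False: rigid checkpoint tied to the saving DP degree.
            "elastic_checkpoint": True,
        },
    }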
@@ -123,7 +123,8 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                  all_gather_partitions=True,
                  allgather_size=500000000,
                  clip_grad=0.0,
-                 max_elements_per_comm=5e8):
+                 max_elements_per_comm=5e8,
+                 elastic_checkpoint=True):

         if dp_process_group is not None and partition_size is not None:
             raise ValueError("Cannot specify both dp_process_group "

@@ -146,6 +147,9 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         self.max_elements_per_comm = max_elements_per_comm
         logger.info("max_elements_per_comm={}".format(max_elements_per_comm))

+        self.elastic_checkpoint = elastic_checkpoint
+        logger.info(f'ZeRO Elastic Checkpoint = {elastic_checkpoint}')
+
         # param flattened by groups
         self.fp16_groups = []
         self.fp16_groups_flat = []
@@ -757,18 +761,30 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     loss_scale = property(_get_loss_scale, _set_loss_scale)
     cur_scale = property(_get_loss_scale, _set_loss_scale)

+    # Return communication interval paddings for local rank and group
+    def _get_local_group_paddings(self, group_index):
+        local_rank = dist.get_rank(group=self.dp_process_group)
+        sub_partition_indices = [
+            local_rank + (comm_idx * self.partition_count)
+            for comm_idx in range(self.num_comm_intervals_per_group[group_index])
+        ]
+        group_paddings = [
+            self.group_paddings[group_index][sub_idx]
+            for sub_idx in sub_partition_indices
+        ]
+        return group_paddings
+
     # Return group tensor after removing paddings that are added for alignment to DP world size.
     # This method works on the assumption that each group contains sub partitions.
     def _get_groups_without_padding(self, groups_with_padding):
         groups_without_padding = []
-        local_rank = dist.get_rank(group=self.dp_process_group)
-        for i, group in enumerate(groups_with_padding):
-            low_index = local_rank * len(group)
-            high_index = (local_rank + 1) * len(group)
-            group_paddings = self.group_paddings[i][low_index:high_index]
+        for group_index, group in enumerate(groups_with_padding):
+            group_paddings = self._get_local_group_paddings(group_index)
             lean_sub_partitions = []
-            for j, sub_partition in enumerate(group):
-                lean_length = sub_partition.numel() - group_paddings[j]
+            for sub_partition, padding in zip(group, group_paddings):
+                lean_length = sub_partition.numel() - padding
                 lean_sub_partitions.append(sub_partition[:lean_length])
             groups_without_padding.append(lean_sub_partitions)
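This indexing change is the heart of the layout fix: group paddings are stored per sub partition in communication-interval-major order, so a rank's entries are strided by the partition count rather than sitting in one contiguous slice. A small standalone sketch of the index math (the sizes below are made up for illustration, not taken from the commit):

    # Old vs. new padding-index layout used by _get_local_group_paddings.
    partition_count = 4        # data-parallel world size
    num_comm_intervals = 3     # comm intervals for this param group
    local_rank = 1

    # Old (buggy) logic assumed this rank's paddings were contiguous.
    old_indices = list(range(local_rank * num_comm_intervals,
                             (local_rank + 1) * num_comm_intervals))  # [3, 4, 5]

    # New logic: sub partitions are laid out comm-interval-major, so the
    # local rank's entries are strided by partition_count.
    new_indices = [local_rank + comm_idx * partition_count
                   for comm_idx in range(num_comm_intervals)]         # [1, 5, 9]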
@@ -790,12 +806,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     # This method assumes that each param group contains a single flattened tensor.
     def _get_base_optimizer_state(self):
         optimizer_groups_state = []
-        local_rank = dist.get_rank(group=self.dp_process_group)
-        for group_idx, group in enumerate(self.optimizer.param_groups):
+        for group_index, group in enumerate(self.optimizer.param_groups):
+            param_paddings = self._get_local_group_paddings(group_index)
             group_lean_state = []
-            low_index = local_rank * self.num_comm_intervals_per_group[group_idx]
-            high_index = (local_rank + 1) * self.num_comm_intervals_per_group[group_idx]
-            param_paddings = self.group_paddings[group_idx][low_index:high_index]
             for param_idx, param in enumerate(group['params']):
                 lean_state = self._get_state_without_padding(self.optimizer.state[param],
                                                              param_paddings[param_idx])
@@ -805,7 +820,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         return optimizer_groups_state

-    def state_dict(self):
+    def _rigid_state_dict(self):
+        """
+        Returns a dict that can be loaded for continued training with same DP degree
+        """
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
         This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
@@ -820,6 +838,19 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         state_dict['loss_scaler'] = self.loss_scaler
         state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
         state_dict['overflow'] = self.overflow
+        state_dict['base_optimizer_state'] = self.optimizer.state_dict()
+        state_dict[
+            'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
+        return state_dict
+
+    def _elastic_state_dict(self):
+        """
+        Returns a dict that can be loaded for elastic training with different DP degree
+        """
+        state_dict = {}
+        state_dict['loss_scaler'] = self.loss_scaler
+        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
+        state_dict['overflow'] = self.overflow
         state_dict['base_optimizer_state'] = self._get_base_optimizer_state()

         state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
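The rigid dict stores this rank's padded local fp32 sub partitions plus the inner optimizer's raw state, so it can only be reloaded at the same data-parallel degree; the elastic dict stores padding-stripped per-sub-partition state that can later be merged and re-partitioned. A rough sketch of the keys each path writes, taken from the assignments above (the full implementation may write additional entries):

    # Keys written by the two checkpoint paths (values elided).
    rigid_keys = [
        'loss_scaler', 'dynamic_loss_scale', 'overflow',
        'base_optimizer_state',                 # self.optimizer.state_dict()
        'local_sub_partitions_of_fp32_groups',  # padded local fp32 sub partitions
    ]
    elastic_keys = [
        'loss_scaler', 'dynamic_loss_scale', 'overflow',
        'base_optimizer_state',                 # _get_base_optimizer_state(), padding stripped
        'zero_stage',                           # ZERO_OPTIMIZATION_OPTIMIZER_STATES
    ]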
@@ -833,13 +864,40 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         return state_dict

+    def state_dict(self):
+        """
+        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
+        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
+        of the contained Pytorch optimizer.
+        Example::
+            checkpoint = {}
+            checkpoint['model'] = model.state_dict()
+            checkpoint['optimizer'] = optimizer.state_dict()
+            torch.save(checkpoint, "saved.pth")
+        """
+        if self.elastic_checkpoint:
+            return self._elastic_state_dict()
+
+        return self._rigid_state_dict()
+
+    # Extract the fp32 weights of the current rank from checkpoint by merging the
+    # sub partitions of communication intervals across ranks.
+    # Let sub_i_j = sub partition of rank i and comm interval j
+    # For 2 ranks and 2 comm intervals, checkpoints (minus padding) are as follows:
+    # rank 0 = [sub_0_0, sub_0_1]
+    # rank 1 = [sub_1_0, sub_1_1]
+    # Merge to get [sub_0_0, sub_1_0, sub_0_1, sub_1_1] => original un-padded flattened tensor.
     def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-
-        all_sub_partition_weights = []
-        for partition_weights in all_partition_fp32_weights:
-            for sub_partition_weights in partition_weights:
-                all_sub_partition_weights.append(sub_partition_weights)
+        num_partitions = len(all_partition_fp32_weights)
+        num_comm_intervals = len(all_partition_fp32_weights[0])
+        num_sub_partitions = num_partitions * num_comm_intervals
+        all_sub_partition_weights = [None] * num_sub_partitions
+
+        for rank, partition_weights in enumerate(all_partition_fp32_weights):
+            for comm_idx, sub_partition_weights in enumerate(partition_weights):
+                sub_partition_idx = (comm_idx * num_partitions) + rank
+                all_sub_partition_weights[sub_partition_idx] = sub_partition_weights

         flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_sub_partition_weights,

@@ -855,6 +913,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             dp_process_group=self.dp_process_group
         )

+        partition_id = dist.get_rank(group=self.dp_process_group)
         return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]

     # Restore base optimizer fp32 weights from checkpoint by:
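The merge order described in the comment can be checked with a tiny standalone example matching the 2-rank, 2-interval case (placeholder strings stand in for tensors):

    # Standalone check of the interleaving formula used in
    # _retrieve_group_sub_partition_weights.
    all_partition_fp32_weights = [
        ["sub_0_0", "sub_0_1"],  # rank 0: one sub partition per comm interval
        ["sub_1_0", "sub_1_1"],  # rank 1
    ]

    num_partitions = len(all_partition_fp32_weights)          # 2
    num_comm_intervals = len(all_partition_fp32_weights[0])   # 2
    merged = [None] * (num_partitions * num_comm_intervals)

    for rank, partition_weights in enumerate(all_partition_fp32_weights):
        for comm_idx, sub_partition_weights in enumerate(partition_weights):
            merged[(comm_idx * num_partitions) + rank] = sub_partition_weights

    assert merged == ["sub_0_0", "sub_1_0", "sub_0_1", "sub_1_1"]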
@@ -903,13 +962,18 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     # 3) Return state corresponding to local partition
     def _retrieve_group_optimizer_states(self, all_partition_states):
         merged_optimizer_states = {}
-        for partition_state in all_partition_states:
-            for sub_partition_state in partition_state:
+        num_partitions = len(all_partition_states)
+        num_comm_intervals = len(all_partition_states[0])
+        num_sub_partitions = num_partitions * num_comm_intervals
+
+        for rank, partition_state in enumerate(all_partition_states):
+            for comm_idx, sub_partition_state in enumerate(partition_state):
                 for key, value in sub_partition_state.items():
                     if not key in merged_optimizer_states.keys():
-                        merged_optimizer_states[key] = [value]
-                    else:
-                        merged_optimizer_states[key].append(value)
+                        merged_optimizer_states[key] = [None] * num_sub_partitions
+
+                    sub_partition_idx = (comm_idx * num_partitions) + rank
+                    merged_optimizer_states[key][sub_partition_idx] = value

         group_optimizer_states = {}
         for key, value in merged_optimizer_states.items():
@@ -950,10 +1014,23 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     def refresh_fp32_params(self):
         self._restore_from_fp16_weights()

-    def load_state_dict(self,
-                        state_dict_list,
-                        load_optimizer_states=True,
-                        load_from_fp32_weights=False):
+    def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
+        # I think it should actually be ok to reload the optimizer before the model.
+        self.loss_scaler = state_dict['loss_scaler']
+        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
+        self.overflow = state_dict['overflow']
+        if load_optimizer_states:
+            self.optimizer.load_state_dict(state_dict['base_optimizer_state'])
+
+        for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
+            for curr_param, saved_param in zip(curr_group, saved_group):
+                curr_param.data.copy_(saved_param.data)
+
+    def _elastic_load_state_dict(self,
+                                 state_dict_list,
+                                 load_optimizer_states=True,
+                                 load_from_fp32_weights=False):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -981,3 +1058,46 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             self._restore_from_fp32_weights(state_dict_list)
         else:
             self._restore_from_fp16_weights()
+
+    def load_state_dict(self,
+                        state_dict_list,
+                        load_optimizer_states=True,
+                        load_from_fp32_weights=False):
+        """
+        Loads a state_dict created by an earlier call to state_dict().
+        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
+        whose parameters in turn came from ``model``, it is expected that the user
+        will call ``model.load_state_dict()`` before
+        ``fp16_optimizer_instance.load_state_dict()`` is called.
+        Example::
+            model = torch.nn.Linear(D_in, D_out).cuda().half()
+            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+            ...
+            checkpoint = torch.load("saved.pth")
+            model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        """
+        if self.elastic_checkpoint:
+            self._elastic_load_state_dict(state_dict_list,
+                                          load_optimizer_states,
+                                          load_from_fp32_weights)
+        else:
+            self._rigid_load_state_dict(
+                state_dict_list[dist.get_rank(group=self.dp_process_group)],
+                load_optimizer_states)
+
+    def _dump_optimizer_state(self, message):
+        logger.info(f'{message}')
+        for i, group in enumerate(self.optimizer.param_groups):
+            for j, param in enumerate(group['params']):
+                for key, value in self.optimizer.state[param].items():
+                    t_stats = [
+                        value.min(),
+                        value.max(),
+                        (value.max() - value.min()),
+                        value.mean()
+                    ]
+                    stats = [float(t) for t in t_stats]
+                    logger.info(
+                        f'group/param/key/min/max/delta/mean = {i}, {j}, {key}: {stats}')
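Because padding is stripped per sub partition before saving and the pieces are merged back into the original flat order at load time, an elastic checkpoint can be re-partitioned for a different data-parallel degree, whereas the rigid path reloads each rank's saved dict verbatim and therefore requires the same DP degree. A simplified, torch-free simulation of that round trip (all sizes are made up; this mirrors the idea of the elastic path, not its exact implementation):

    # Strip alignment padding at save time, merge at load time, then re-split
    # for a new DP degree.
    def strip_padding(sub_partition, padding):
        return sub_partition[:len(sub_partition) - padding]

    # Saved by 2 ranks (one comm interval each): 3 real elements + 1 pad element.
    rank0_lean = strip_padding([1, 2, 3, 0], padding=1)   # [1, 2, 3]
    rank1_lean = strip_padding([4, 5, 6, 0], padding=1)   # [4, 5, 6]

    # Merge in sub-partition order to recover the original flat parameter order
    # (see the interleaving check above for the multi-interval case).
    flat = rank0_lean + rank1_lean                         # [1, 2, 3, 4, 5, 6]

    # Re-partition for a new DP degree of 3; no extra padding needed since 6 % 3 == 0.
    new_dp = 3
    per_rank = len(flat) // new_dp
    repartitioned = [flat[i * per_rank:(i + 1) * per_rank] for i in range(new_dp)]
    assert repartitioned == [[1, 2], [3, 4], [5, 6]]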
@@ -256,19 +256,16 @@ def test_checkpoint_fused_optimizer(tmpdir):
                               load_optimizer_states=False)


-@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                          [
                              (1,
-                              False,
-                              'Adam'),
+                              False),
                              (2,
-                              False,
-                              'Adam'),
+                              False),
                              (2,
-                              True,
-                              'Adam'),
+                              True),
                          ])
-def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
+def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")

@@ -276,7 +273,7 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": adam_optimizer,
+            "type": 'Adam',
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -312,22 +309,16 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
                             load_optimizer_states=True)


-@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                          [
                              (1,
-                              False,
-                              "Adam"),
+                              False),
                              (2,
-                              False,
-                              "Adam"),
+                              False),
                              (2,
-                              True,
-                              'Adam'),
+                              True),
                          ])
-def test_checkpoint_zero_no_optimizer(tmpdir,
-                                      zero_stage,
-                                      use_cpu_offload,
-                                      adam_optimizer):
+def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")

@@ -335,7 +326,7 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": adam_optimizer,
+            "type": 'Adam',
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -374,22 +365,18 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
                             load_optimizer_states=False)


-@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                          [
                              (0,
-                              False,
-                              'Adam'),
+                              False),
                              (1,
-                              False,
-                              'Adam'),
+                              False),
                              (2,
-                              False,
-                              'Adam'),
+                              False),
                              (2,
-                              True,
-                              'Adam'),
+                              True),
                          ])
-def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
+def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")

@@ -397,7 +384,7 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": adam_optimizer,
+            "type": 'Adam',
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -448,22 +435,18 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                             load_lr_scheduler_states=True)


-@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload',
                          [
                              (0,
-                              False,
-                              'Adam'),
+                              False),
                              (1,
-                              False,
-                              'Adam'),
+                              False),
                              (2,
-                              False,
-                              'Adam'),
+                              False),
                              (2,
-                              True,
-                              'Adam'),
+                              True),
                          ])
-def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
+def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")

@@ -471,7 +454,7 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": adam_optimizer,
+            "type": 'Adam',
             "params": {
                 "lr": 1e-5
             }