Unverified Commit fa87a73a authored by Olatunji Ruwase, committed by GitHub

Fix ZeRO3 save_checkpoint (#857)


Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 871f3048
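The diff below un-skips the ZeRO stage 3 checkpointing tests, replaces the `deepspeed.ScatteredParameters(zero_modules=True)` context manager with `deepspeed.zero.Init()`, fixes how the stage-3 optimizer attaches its flat fp32 partitions to the wrapped optimizer's param groups, and points the test helper at the stage-3 attribute `fp32_partitioned_groups_flat`. For orientation, a minimal sketch of the save/load round trip these tests exercise; `SimpleModel`, the checkpoint directory, the tag, and the learning rate are illustrative stand-ins, and `deepspeed.initialize(config=...)` assumes a DeepSpeed version that accepts an in-memory config dict:

```python
# Minimal sketch of the ZeRO-3 checkpoint round trip; run under a
# distributed launcher (e.g. `deepspeed this_script.py`).
import torch
import deepspeed

config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},  # lr is illustrative
    "fp16": {"enabled": True, "initial_scale_power": 8},
    "zero_optimization": {"stage": 3, "cpu_offload": False},
}

class SimpleModel(torch.nn.Module):  # stand-in for the test model
    def __init__(self, hidden_dim=10):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.linear(x)

# Under ZeRO-3, construct the model inside zero.Init() so its parameters
# are partitioned across ranks at creation time.
with deepspeed.zero.Init():
    model = SimpleModel()

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=config)

engine.save_checkpoint("/tmp/zero3_ckpt", tag="step0")
engine.load_checkpoint("/tmp/zero3_ckpt", tag="step0", load_optimizer_states=True)
```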
@@ -2269,7 +2269,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
             "averaged gradients have different number of elements that partition size {} {} {} {}".format(
-                single_grad_partition.numel(), self.partition_size[sub_group_id], sub_group_id, partition_id)
+                single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id)
         self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
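This first hunk changes only the assertion's failure message: the check compares `single_grad_partition.numel()` against `self.fp32_partitioned_groups_flat[sub_group_id].numel()`, so on failure the message should report that same quantity rather than the unrelated `self.partition_size[sub_group_id]`. A toy illustration of keeping a message in sync with the check it describes (names are illustrative):

```python
import torch

single_grad_partition = torch.zeros(8)
fp32_partition = torch.zeros(8)

# Format the same value the comparison uses, so a failing message can
# never report a number the check did not actually test.
expected = fp32_partition.numel()
assert single_grad_partition.numel() == expected, (
    f"averaged gradients have {single_grad_partition.numel()} elements "
    f"but the fp32 partition has {expected}")
```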
@@ -2638,14 +2638,12 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
     def _set_fp32_optimizer_param_groups(self):
         for sub_group_id, _ in enumerate(self.fp16_groups):
             param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = [
-                self.fp32_partitioned_groups_flat[sub_group_id]
-            ]
+            self.optimizer.param_groups[param_group_id]['params'].append(
+                self.fp32_partitioned_groups_flat[sub_group_id])
 
     def _clear_fp32_optimizer_param_groups(self):
-        for sub_group_id, _ in enumerate(self.fp16_groups):
-            param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = []
+        for param_group in self.optimizer.param_groups:
+            param_group['params'] = []
 
     def _rigid_state_dict(self):
         state_dict = {}
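These two methods bracket checkpointing: the first temporarily hands the flat fp32 partitions to the wrapped optimizer, the second detaches them. Appending (instead of overwriting) lets multiple sub-groups that map to the same param group coexist, and the clear now empties every group rather than only those reachable through the sub-group mapping. A toy sketch of the pattern, assuming two sub-groups mapped to one param group (all names illustrative):

```python
import torch

# Two flat fp32 partitions that both map to param group 0, mirroring
# sub_group_to_group_id in the real optimizer.
fp32_partitioned_groups_flat = [torch.nn.Parameter(torch.zeros(4)) for _ in range(2)]
sub_group_to_group_id = {0: 0, 1: 0}

optimizer = torch.optim.Adam([{"params": []}], lr=1e-3)

# Set: append, so the second sub-group does not clobber the first.
for sub_group_id, flat in enumerate(fp32_partitioned_groups_flat):
    group_id = sub_group_to_group_id[sub_group_id]
    optimizer.param_groups[group_id]["params"].append(flat)
assert len(optimizer.param_groups[0]["params"]) == 2

# Clear: reset every param group, independent of the mapping.
for param_group in optimizer.param_groups:
    param_group["params"] = []
```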
@@ -47,7 +47,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
             saved_model.optimizer,
             FP16_DeepSpeedZeroOptimizer_Stage3):
-        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
+        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
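The helper previously read `fp32_groups_flat`, an attribute the stage-3 optimizer does not expose; its flat fp32 buffers live in `fp32_partitioned_groups_flat` (one per sub-group, and partitioned across ranks, hence the name). What the loop checks, with toy tensors standing in for the partitions:

```python
import torch

saved_partitions = [torch.randn(16), torch.randn(16)]
loaded_partitions = [t.clone() for t in saved_partitions]

# Each rank compares its own flat fp32 partitions element-wise.
for p0, p1 in zip(saved_partitions, loaded_partitions):
    assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} != {p1}"
```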
@@ -303,12 +303,13 @@ def test_checkpoint_fused_optimizer(tmpdir):
                              'deepspeed_adam'),
                             (3,
                              False,
-                             'Adam')])
+                             'Adam'),
+                            (3,
+                             True,
+                             'deepspeed_adam')])
 def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -324,8 +325,10 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
             }
         },
         "fp16": {
-            "enabled": True
+            "enabled": True,
+            "initial_scale_power": 8
         },
+        "wall_clock_breakdown": True,
         "zero_optimization": {
             "stage": zero_stage,
             "cpu_offload": use_cpu_offload
@@ -340,9 +343,7 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                           hidden_dim,
                           load_optimizer_states):
         if zero_stage == 3:
-            global FP16_DeepSpeedZeroOptimizer_Stage3
-            from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
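The remaining model-construction hunks are the same rename: `deepspeed.ScatteredParameters(zero_modules=True)` becomes `deepspeed.zero.Init()`, under which submodules are partitioned across ranks as they are constructed. A minimal usage sketch; run under a distributed launcher, and note that `ds_numel` (the full, unpartitioned element count ZeRO-3 records on each parameter) is an assumption beyond this diff:

```python
import torch
import deepspeed

# Parameters are partitioned as each Linear is built, so the whole model
# never materializes on one device.
with deepspeed.zero.Init():
    model = torch.nn.Sequential(torch.nn.Linear(512, 512),
                                torch.nn.Linear(512, 512))

for p in model.parameters():
    # ZeRO-3 tags each parameter with its full size before partitioning.
    print(p.ds_numel)
```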
@@ -371,15 +372,16 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                              'deepspeed_adam'),
                             (3,
                              False,
-                             'Adam')])
+                             'Adam'),
+                            (3,
+                             True,
+                             'deepspeed_adam')])
 def test_checkpoint_zero_no_optimizer(tmpdir,
                                       zero_stage,
                                       use_cpu_offload,
                                       adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -413,7 +415,7 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
         if zero_stage == 3:
             global FP16_DeepSpeedZeroOptimizer_Stage3
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -445,12 +447,13 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
                              'deepspeed_adam'),
                             (3,
                              False,
-                             'Adam')])
+                             'Adam'),
+                            (3,
+                             True,
+                             'deepspeed_adam')])
 def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -493,7 +496,7 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
         if zero_stage == 3:
             global FP16_DeepSpeedZeroOptimizer_Stage3
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -526,14 +529,15 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                             (2,
                              True,
                              'deepspeed_adam'),
+                            (3,
+                             False,
+                             'Adam'),
                             (3,
                              True,
-                             'Adam')])
+                             'deepspeed_adam')])
 def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
-    if zero_stage == 3:
-        pytest.skip('Skip checkpointing tests for ZeRO3')
 
     config_dict = {
         "train_batch_size": 2,
@@ -570,7 +574,7 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                       load_optimizer_states,
                       load_lr_scheduler_states):
         if zero_stage == 3:
-            with deepspeed.ScatteredParameters(zero_modules=True):
+            with deepspeed.zero.Init():
                 models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
         else:
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]