Unverified commit f2ac7eaf authored by Jeff Rasley, committed by GitHub

ZeRO-2 (#217)



Updates for ZeRO stage 2, plus ZeRO stage 1 with reduce-scatter (RS).
Co-authored-by: Tunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: Elton Zheng <eltonz@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: yuxionghe <yuxhe@microsoft.com>
Co-authored-by: Arash Ashari <arashari@microsoft.com>
parent c61e23b4
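
For context, the main user-facing change these tests exercise is that `zero_optimization` moves from a boolean flag to a dict selecting the ZeRO stage. Below is a minimal sketch (not part of this commit) mirroring the pattern of the tests in this diff; `SimpleModel` and `args_from_dict` are the existing unit-test helpers, `example` is an illustrative name, and a real run would go through the `@distributed_test` decorator or the `deepspeed` launcher.

```python
import deepspeed
from simple_model import SimpleModel, args_from_dict  # test helpers used throughout this diff

config_dict = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "fp16": {"enabled": True},
    # Old (now deprecated) form: "zero_optimization": True
    "zero_optimization": {"stage": 2},  # stage 1: partition optimizer states; stage 2: also partition gradients
}

def example(tmpdir):
    # args_from_dict writes config_dict to a JSON file under tmpdir and returns argparse-style args.
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)
    engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                   model=model,
                                                   model_parameters=model.parameters())
    return engine, optimizer
```
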
@@ -8,7 +8,7 @@ from torch.multiprocessing import Process
import pytest
# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 10
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
def distributed_test(world_size=2, backend='nccl'):
......
import torch
import deepspeed
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
@@ -9,6 +10,7 @@ import argparse
import pytest
import json
import os
import numbers
from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict
@@ -22,21 +24,6 @@ def compare_deepspeed_states(saved_model, loaded_model):
assert saved_model.global_steps == loaded_model.global_steps
def compare_lr_scheduler_states(saved_model, loaded_model):
if saved_model.lr_scheduler is None:
assert loaded_model.lr_scheduler is None
return
saved = saved_model.lr_scheduler.state_dict()
loaded = loaded_model.lr_scheduler.state_dict()
assert sorted(saved.keys()) == sorted(loaded.keys())
for key in saved.keys():
if isinstance(saved[key], torch.Tensor):
assert torch.equal(saved[key], loaded[key])
else:
assert saved[key] == loaded[key]
def compare_model_states(saved_model, loaded_model):
compare_deepspeed_states(saved_model, loaded_model)
@@ -47,6 +34,11 @@ def compare_model_states(saved_model, loaded_model):
for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
for p0, p1 in zip(partition0, partition1):
assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_Optimizer):
for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
@@ -61,10 +53,6 @@ def compare_model_states(saved_model, loaded_model):
def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
compare_model_states(saved_model, loaded_model)
assert hasattr(loaded_model, 'optimizer')
for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
loaded_model.optimizer.optimizer.state.values()):
for s0, s1 in zip(state0.values(), state1.values()):
@@ -74,11 +62,35 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
assert s0 == s1
def checkpoint_correctness_verification(save_folder,
args,
def compare_lr_scheduler_states(saved_model, loaded_model):
assert hasattr(saved_model, 'lr_scheduler')
assert hasattr(loaded_model, 'lr_scheduler')
saved_scheduler = saved_model.lr_scheduler
loaded_scheduler = loaded_model.lr_scheduler
assert hasattr(saved_scheduler, 'state_dict')
assert hasattr(loaded_scheduler, 'state_dict')
saved_sd = saved_scheduler.state_dict()
loaded_sd = loaded_scheduler.state_dict()
print(f"saved_sd = {saved_sd}")
print(f"loaded_sd = {loaded_sd}")
assert saved_sd.keys() == loaded_sd.keys()
for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
assert state0 == state1
def checkpoint_correctness_verification(args,
model,
hidden_dim,
load_optimizer_states=True):
tmpdir,
load_optimizer_states=False,
load_lr_scheduler_states=False):
ds_model, _, _,_ = deepspeed.initialize(args=args,
model=model,
@@ -94,6 +106,7 @@ def checkpoint_correctness_verification(save_folder,
trained_model = ds_model
save_folder = os.path.join(tmpdir, 'saved_checkpoint')
save_tag = '1'
trained_model.save_checkpoint(save_folder, save_tag)
@@ -104,14 +117,16 @@ def checkpoint_correctness_verification(save_folder,
loaded_model.load_checkpoint(save_folder,
save_tag,
load_optimizer_states=load_optimizer_states)
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
compare_lr_scheduler_states(trained_model, loaded_model)
compare_model_states(trained_model, loaded_model)
if load_optimizer_states:
compare_optimizer_states(trained_model, loaded_model, hidden_dim)
else:
compare_model_states(trained_model, loaded_model)
if load_lr_scheduler_states:
compare_lr_scheduler_states(trained_model, loaded_model)
def test_checkpoint_unfused_optimizer(tmpdir):
@@ -156,10 +171,10 @@ def test_checkpoint_unfused_optimizer(tmpdir):
model,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(tmpdir,
args,
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_unfused_optimizer(args=args,
@@ -198,10 +213,10 @@ def test_checkpoint_fused_optimizer(tmpdir):
@distributed_test(world_size=[2])
def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
checkpoint_correctness_verification(tmpdir,
args,
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_fused_optimizer(args=args,
@@ -214,7 +229,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
load_optimizer_states=False)
def test_checkpoint_zero_optimizer(tmpdir):
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_checkpoint_zero_optimizer(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
@@ -231,7 +247,9 @@ def test_checkpoint_zero_optimizer(tmpdir):
"fp16": {
"enabled": True
},
"zero_optimization": True
"zero_optimization": {
"stage": zero_stage
},
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
@@ -240,17 +258,165 @@ def test_checkpoint_zero_optimizer(tmpdir):
@distributed_test(world_size=[2])
def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
checkpoint_correctness_verification(tmpdir,
args,
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_zero_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=True)
_test_checkpoint_zero_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False)
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"betas": [0.8,
0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[2])
def _test_checkpoint_zero_no_optimizer(args,
model,
hidden_dim,
load_optimizer_states):
checkpoint_correctness_verification(args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states)
_test_checkpoint_zero_no_optimizer(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False)
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
"betas": [0.8,
0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[2])
def _test_checkpoint_lr_scheduler(args,
model,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
_test_checkpoint_lr_scheduler(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=True)
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[2])
def _test_checkpoint_no_lr_scheduler(args,
model,
hidden_dim,
load_optimizer_states,
load_lr_scheduler_states):
checkpoint_correctness_verification(
args,
model,
hidden_dim,
tmpdir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)
_test_checkpoint_no_lr_scheduler(args=args,
model=model,
hidden_dim=hidden_dim,
load_optimizer_states=False,
load_lr_scheduler_states=False)
@@ -17,3 +17,19 @@ def test_only_required_fields(tmpdir):
assert run_cfg.train_batch_size == 64
assert run_cfg.train_micro_batch_size_per_gpu == 64
assert run_cfg.gradient_accumulation_steps == 1
def test_config_duplicate_key(tmpdir):
config_dict = '''
{
"train_batch_size": 24,
"train_batch_size": 24,
}
'''
config_path = os.path.join(tmpdir, 'temp_config.json')
with open(config_path, 'w') as jf:
jf.write("%s" % config_dict)
with pytest.raises(ValueError):
run_cfg = ds_config.DeepSpeedConfig(config_path)
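
The hunk above only adds the test; the library-side check is not shown here. One common way to implement such a check (a sketch, not necessarily the exact `DeepSpeedConfig` implementation, and `load_config` is an illustrative helper name) is a `json.load` `object_pairs_hook` that rejects repeated keys:

```python
import json

def dict_raise_error_on_duplicate_keys(ordered_pairs):
    """Build a dict from JSON key/value pairs, raising if any key repeats."""
    d = dict(ordered_pairs)
    if len(d) != len(ordered_pairs):
        keys = [k for k, _ in ordered_pairs]
        dupes = {k for k in keys if keys.count(k) > 1}
        raise ValueError(f"Duplicate keys in DeepSpeed config: {dupes}")
    return d

def load_config(config_path):
    # The hook runs for every JSON object, so duplicates in nested sections are caught too.
    with open(config_path, 'r') as f:
        return json.load(f, object_pairs_hook=dict_raise_error_on_duplicate_keys)
```
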
@@ -144,7 +144,8 @@ def test_adamw_fp16_empty_grad(tmpdir):
_test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
def test_adam_fp16_onecycle_compatibility(tmpdir):
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
@@ -171,15 +172,18 @@ def test_adam_fp16_onecycle_compatibility(tmpdir):
"fp16": {
"enabled": True
},
"zero_optimization": False
"zero_optimization": {
"stage": zero_stage
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=[1])
def _test_adam_fp16_onecycle_compatibility(args, model, hidden_dim):
def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
@@ -192,12 +196,15 @@ def test_adam_fp16_onecycle_compatibility(tmpdir):
model.backward(loss)
model.step()
_test_adam_fp16_onecycle_compatibility(args=args, model=model, hidden_dim=hidden_dim)
_test_adam_fp16_zero_onecycle_compatibility(args=args,
model=model,
hidden_dim=hidden_dim)
def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_zero_static_scale(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 1,
"train_batch_size": 4,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
@@ -205,37 +212,31 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
"lr": 0.00015
}
},
"scheduler": {
"type": "OneCycle",
"params": {
"cycle_first_step_size": 16000,
"cycle_first_stair_count": 8000,
"decay_step_size": 16000,
"cycle_min_lr": 1e-06,
"cycle_max_lr": 3e-05,
"decay_lr_rate": 1e-07,
"cycle_min_mom": 0.85,
"cycle_max_mom": 0.99,
"decay_mom_rate": 0.0
}
},
"fp16": {
"enabled": True
"enabled": True,
"loss_scale": 138.
},
"zero_optimization": True
"zero_optimization": {
"stage": zero_stage
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=2)
def _test_zero_static_scale(args):
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True)
model, optim, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
@distributed_test(world_size=[1])
def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
# Ensure the static scaler is configured.
assert optim.dynamic_loss_scale == False
assert optim.loss_scaler.loss_scale == 138.
# Now make sure things work..
data_loader = random_dataloader(model=model,
total_samples=50,
total_samples=10,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
@@ -243,12 +244,10 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
model.backward(loss)
model.step()
_test_adam_fp16_zero_onecycle_compatibility(args=args,
model=model,
hidden_dim=hidden_dim)
_test_zero_static_scale(args)
def test_zero_static_scale(tmpdir):
def test_zero_static_scale_deprecated_format(tmpdir):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
@@ -291,14 +290,17 @@ def test_zero_static_scale(tmpdir):
_test_zero_static_scale(args)
def test_zero_allow_untested_optimizer(tmpdir):
@pytest.mark.parametrize("zero_stage", [1, 2])
def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
"fp16": {
"enabled": True,
},
"zero_optimization": True,
"zero_optimization": {
"stage": zero_stage
},
"zero_allow_untested_optimizer": False
}
args = args_from_dict(tmpdir, config_dict)
@@ -317,31 +319,34 @@ def test_zero_allow_untested_optimizer(tmpdir):
_test_zero_allow_untested_optimizer(args)
def test_zero_empty_partition(tmpdir):
config_dict = {
"train_batch_size": 3,
"fp16": {
"enabled": True
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"zero_optimization": True
}
args = args_from_dict(tmpdir, config_dict)
@distributed_test(world_size=[3])
def _test_zero_empty_partition(args):
hidden_dim = 1
model = SimpleModel(hidden_dim)
# Ensure model has 2 parameters, to cause empty partition with DP=3
assert len(list(model.parameters())) == 2
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
model.step()
_test_zero_empty_partition(args)
# @pytest.mark.parametrize("zero_stage", [1])
# def test_zero_empty_partition(tmpdir, zero_stage):
# config_dict = {
# "train_batch_size": 3,
# "fp16": {
# "enabled": True
# },
# "optimizer": {
# "type": "Adam",
# "params": {
# "lr": 0.00015
# }
# },
# "zero_optimization": {
# "stage": zero_stage
# }
# }
# args = args_from_dict(tmpdir, config_dict)
# @distributed_test(world_size=[3])
# def _test_zero_empty_partition(args):
# hidden_dim = 1
# model = SimpleModel(hidden_dim)
# # Ensure model has 2 parameters, to cause empty partition with DP=3
# assert len(list(model.parameters())) == 2
# model, _, _, _ = deepspeed.initialize(args=args,
# model=model,
# model_parameters=model.parameters())
# model.step()
# _test_zero_empty_partition(args)
@@ -73,7 +73,7 @@ def test_two_output_model(tmpdir):
summed_loss = sum(loss_tuple)
scaled_loss = model.backward(summed_loss)
expected_scaled_loss = summed_loss / gradient_accumulation_steps
expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
assert scaled_loss.item() == approx(expected_scaled_loss.item())
model.step()
@@ -131,7 +131,7 @@ def test_three_output_model(tmpdir):
summed_loss = sum(loss_tuple)
scaled_loss = model.backward(summed_loss)
expected_scaled_loss = summed_loss / gradient_accumulation_steps
expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
assert scaled_loss.item() == approx(expected_scaled_loss.item())
model.step()
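
A note on the `.float()` change in the two hunks above: `summed_loss` is an fp16 tensor, so dividing it by `gradient_accumulation_steps` in half precision can round differently from the engine's own loss scaling (presumably done in fp32), which makes the `approx` comparison flaky; casting to fp32 first avoids that. A small illustration of the rounding difference, using an arbitrary value:

```python
import torch

summed_loss = torch.tensor(4.3203, dtype=torch.half)  # arbitrary fp16 loss value
gradient_accumulation_steps = 3

fp16_first = (summed_loss / gradient_accumulation_steps).item()          # divide in fp16, then read back
fp32_first = (summed_loss.float() / gradient_accumulation_steps).item()  # cast to fp32, then divide
print(fp16_first, fp32_first)  # the two values differ slightly due to fp16 rounding
```
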