Unverified Commit b2c87edf authored by Shaden Smith, committed by GitHub

Fix global_steps checkpoint loading. (#139)

parent 4f42bbb0
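
The one-line fix below renames 'global_step' to 'global_steps' in the list of engine-managed checkpoint keys, so the load path matches the key the engine actually saves. A minimal sketch of the failure mode, assuming the list is used to separate engine state from user-supplied client_state entries when a checkpoint is read back (illustrative stand-in, not the actual DeepSpeed code):

    # Illustrative sketch only -- simplified stand-in for the real load path.
    checkpoint = {'global_steps': 1000, 'skipped_steps': 2, 'epoch': 3}

    # Bug: the exclusion list misspelled the key ('global_step', no trailing 's').
    engine_managed = ['optimizer', 'csr_tensor_module_names',
                      'skipped_steps', 'global_step']
    client_state = {k: v for k, v in checkpoint.items() if k not in engine_managed}
    assert 'global_steps' in client_state        # engine state leaks to the caller

    # Fix: spell the key exactly as it is saved.
    engine_managed[-1] = 'global_steps'
    client_state = {k: v for k, v in checkpoint.items() if k not in engine_managed}
    assert 'global_steps' not in client_state    # only user entries remain
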
@@ -8,7 +8,7 @@ efficient, and effective.
<p align="center"><i><b>5x Faster Training</b></i></p>
<p align="center"><i><b>Minimal Code Change</b></i></p>
-DeepSpeed can train DL models with over a hundred billion parameters on current
+DeepSpeed can train deep learning models with over a hundred billion parameters on current
generation of GPU clusters, while achieving over 5x in system performance
compared to the state-of-art. Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
@@ -1021,7 +1021,7 @@ class DeepSpeedLight(Module):
            'optimizer',
            'csr_tensor_module_names',
            'skipped_steps',
-           'global_step'
+           'global_steps'
        ]
        client_state = {
            key: value
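
Seen from the caller's side, the round trip looks roughly like this (a hedged sketch; model_engine is a DeepSpeed engine, and the folder and tag names are made up):

    # Save: engine-managed state plus any user-supplied entries in one checkpoint.
    model_engine.save_checkpoint('saved_checkpoint', '1',
                                 client_state={'epoch': 3})

    # Load: the engine consumes its own keys (including 'global_steps') and
    # hands the remaining user entries back.
    load_path, client_state = model_engine.load_checkpoint('saved_checkpoint', '1')
    assert client_state['epoch'] == 3

Before this fix, the engine's own 'global_steps' entry could leak into the returned client_state because the exclusion list spelled the key without the trailing 's'.
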
@@ -13,7 +13,33 @@ from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict

def compare_deepspeed_states(saved_model, loaded_model):
    # These are compared in more depth in other places
    assert hasattr(loaded_model, 'module')
    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
    assert saved_model.skipped_steps == loaded_model.skipped_steps
    assert saved_model.global_steps == loaded_model.global_steps

def compare_lr_scheduler_states(saved_model, loaded_model):
    if saved_model.lr_scheduler is None:
        assert loaded_model.lr_scheduler is None
        return
    saved = saved_model.lr_scheduler.state_dict()
    loaded = loaded_model.lr_scheduler.state_dict()
    assert sorted(saved.keys()) == sorted(loaded.keys())
    for key in saved.keys():
        if isinstance(saved[key], torch.Tensor):
            assert torch.equal(saved[key], loaded[key])
        else:
            assert saved[key] == loaded[key]

def compare_model_states(saved_model, loaded_model):
    compare_deepspeed_states(saved_model, loaded_model)

    for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
        assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
@@ -37,6 +63,8 @@ def compare_model_states(saved_model, loaded_model):
def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
    compare_model_states(saved_model, loaded_model)
+    assert hasattr(loaded_model, 'optimizer')

    for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
                              loaded_model.optimizer.optimizer.state.values()):
        for s0, s1 in zip(state0.values(), state1.values()):
@@ -46,7 +74,8 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
            assert s0 == s1
-def checkpoint_correctness_verification(args,
+def checkpoint_correctness_verification(save_folder,
+                                        args,
                                        model,
                                        hidden_dim,
                                        load_optimizer_states=True):
@@ -65,7 +94,6 @@ def checkpoint_correctness_verification(args,
    trained_model = ds_model

-    save_folder = 'saved_checkpoint'
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)
@@ -78,6 +106,8 @@ def checkpoint_correctness_verification(args,
                                        save_tag,
                                        load_optimizer_states=load_optimizer_states)

+    compare_lr_scheduler_states(trained_model, loaded_model)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
    else:
@@ -97,6 +127,22 @@ def test_checkpoint_unfused_optimizer(tmpdir):
        },
        "fp16": {
            "enabled": True
        },
+        "scheduler": {
+            "type": "OneCycle",
+            "params": {
+                "cycle_first_step_size": 1000,
+                "cycle_first_stair_count": 500,
+                "cycle_second_step_size": 1000,
+                "cycle_second_stair_count": 500,
+                "decay_step_size": 1000,
+                "cycle_min_lr": 0.0001,
+                "cycle_max_lr": 0.0010,
+                "decay_lr_rate": 0.001,
+                "cycle_min_mom": 0.85,
+                "cycle_max_mom": 0.99,
+                "decay_mom_rate": 0.0
+            }
+        }
    }
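
The scheduler section above gives compare_lr_scheduler_states real state to verify. A sketch of how the test plumbing typically consumes such a config (hedged; exact signatures can vary across DeepSpeed versions):

    import deepspeed

    # args.deepspeed_config points at a JSON file holding the dict above
    # (args_from_dict writes it out under tmpdir).
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters())

    # The engine builds the OneCycle scheduler from the "scheduler" section and
    # steps it internally; its state_dict() must survive save/load unchanged.
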
@@ -110,7 +156,8 @@ def test_checkpoint_unfused_optimizer(tmpdir):
                                           model,
                                           hidden_dim,
                                           load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
@@ -151,7 +198,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
    @distributed_test(world_size=[2])
    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
@@ -192,7 +240,8 @@ def test_checkpoint_zero_optimizer(tmpdir):
    @distributed_test(world_size=[2])
    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
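
All three tests share one shape: build a config dict, materialize it with args_from_dict, and run the body on two ranks under @distributed_test. A hedged outline of that pattern (config trimmed to the minimum; real tests also configure an optimizer, and the exact fields here are assumed):

    def test_checkpoint_pattern(tmpdir):
        config_dict = {
            "train_batch_size": 2,
            "fp16": {"enabled": True}
        }
        args = args_from_dict(tmpdir, config_dict)  # writes the JSON config file
        hidden_dim = 10
        model = SimpleModel(hidden_dim, empty_grad=False)

        @distributed_test(world_size=[2])
        def _run(args, model, hidden_dim):
            checkpoint_correctness_verification(tmpdir, args, model, hidden_dim)

        _run(args, model, hidden_dim)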