"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "06608f847c1e58b16a2c3eb876b1afaa7cc7a7f7"
Unverified commit b2c87edf authored by Shaden Smith, committed by GitHub

Fix global_steps checkpoint loading. (#139)

parent 4f42bbb0

@@ -8,7 +8,7 @@ efficient, and effective.
 <p align="center"><i><b>5x Faster Training</b></i></p>
 <p align="center"><i><b>Minimal Code Change</b></i></p>
 
-DeepSpeed can train DL models with over a hundred billion parameters on current
+DeepSpeed can train deep learning models with over a hundred billion parameters on current
 generation of GPU clusters, while achieving over 5x in system performance
 compared to the state-of-art. Early adopters of DeepSpeed have already produced
 a language model (LM) with over 17B parameters called

@@ -1021,7 +1021,7 @@ class DeepSpeedLight(Module):
             'optimizer',
             'csr_tensor_module_names',
             'skipped_steps',
-            'global_step'
+            'global_steps'
         ]
         client_state = {
             key: value
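
The one-line change above is the heart of the fix. On checkpoint load, DeepSpeedLight filters its own bookkeeping keys out of the checkpoint dict and hands the remainder back to the caller as client_state. Because the filter list named 'global_step' while the engine stores the counter as 'global_steps', the engine's entry slipped through the filter and leaked into the returned user state. A minimal sketch of that filtering pattern follows; it is illustrative only, not DeepSpeed's exact code, and the real list may contain more keys than the hunk shows:

# Illustrative sketch, not DeepSpeed's exact code: engine-owned keys are
# stripped from the loaded checkpoint so that only user-supplied state is
# returned. A misspelled key in this list leaks engine state to the caller.

def extract_client_state(checkpoint):
    deepspeed_keys = [
        'module',
        'optimizer',
        'csr_tensor_module_names',
        'skipped_steps',
        'global_steps',  # pre-fix this read 'global_step', so the engine's
    ]                    # 'global_steps' entry passed through the filter
    return {key: value
            for key, value in checkpoint.items()
            if key not in deepspeed_keys}

# With the corrected key, only user keys survive:
ckpt = {'module': {}, 'optimizer': {}, 'csr_tensor_module_names': [],
        'skipped_steps': 0, 'global_steps': 42, 'my_app_step': 7}
assert extract_client_state(ckpt) == {'my_app_step': 7}
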
@@ -13,7 +13,33 @@ from common import distributed_test
 from simple_model import SimpleModel, random_dataloader, args_from_dict
 
 
+def compare_deepspeed_states(saved_model, loaded_model):
+    # These are compared in more depth in other places
+    assert hasattr(loaded_model, 'module')
+
+    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
+    assert saved_model.skipped_steps == loaded_model.skipped_steps
+    assert saved_model.global_steps == loaded_model.global_steps
+
+
+def compare_lr_scheduler_states(saved_model, loaded_model):
+    if saved_model.lr_scheduler is None:
+        assert loaded_model.lr_scheduler is None
+        return
+
+    saved = saved_model.lr_scheduler.state_dict()
+    loaded = loaded_model.lr_scheduler.state_dict()
+
+    assert sorted(saved.keys()) == sorted(loaded.keys())
+    for key in saved.keys():
+        if isinstance(saved[key], torch.Tensor):
+            assert torch.equal(saved[key], loaded[key])
+        else:
+            assert saved[key] == loaded[key]
+
+
 def compare_model_states(saved_model, loaded_model):
+    compare_deepspeed_states(saved_model, loaded_model)
+
     for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
         assert torch.allclose(p0,p1,atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
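
One detail worth noting in compare_lr_scheduler_states: the isinstance branch is needed because == on tensors is element-wise, so a multi-element comparison cannot be asserted on directly, while torch.equal returns a single bool. A small standalone illustration:

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])

print(a == b)             # tensor([True, True]): element-wise, not one bool
print(torch.equal(a, b))  # True: same shape and values, safe to assert on

# `assert (a == b).all()` would also work, but a bare `assert a == b` raises
# "Boolean value of Tensor with more than one element is ambiguous", which is
# why the helper branches on isinstance.
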
@@ -37,6 +63,8 @@ def compare_model_states(saved_model, loaded_model):
 def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
     compare_model_states(saved_model, loaded_model)
 
+    assert hasattr(loaded_model, 'optimizer')
+
     for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
                               loaded_model.optimizer.optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
@@ -46,7 +74,8 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
             assert s0 == s1
 
 
-def checkpoint_correctness_verification(args,
+def checkpoint_correctness_verification(save_folder,
+                                        args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states=True):
@@ -65,7 +94,6 @@ def checkpoint_correctness_verification(args,
     trained_model = ds_model
 
-    save_folder = 'saved_checkpoint'
     save_tag = '1'
 
     trained_model.save_checkpoint(save_folder, save_tag)
@@ -78,6 +106,8 @@ def checkpoint_correctness_verification(args,
                                        save_tag,
                                        load_optimizer_states=load_optimizer_states)
 
+    compare_lr_scheduler_states(trained_model, loaded_model)
+
     if load_optimizer_states:
         compare_optimizer_states(trained_model, loaded_model, hidden_dim)
     else:
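
The verification flow now also checks scheduler state after the round trip. For reference, the public API being exercised is save_checkpoint/load_checkpoint. A hedged usage sketch of that round trip, where the two engines stand in for a save process and a fresh load process, and my_app_step is a made-up client key:

import deepspeed

def roundtrip(args, model, save_folder='checkpoints', tag='1'):
    # Hedged sketch of the round trip the tests verify; engine construction
    # follows deepspeed.initialize's documented signature.
    engine, _, _, _ = deepspeed.initialize(args=args, model=model,
                                           model_parameters=model.parameters())
    engine.save_checkpoint(save_folder, tag, client_state={'my_app_step': 7})

    # A second engine built the same way stands in for a fresh process.
    fresh, _, _, _ = deepspeed.initialize(args=args, model=model,
                                          model_parameters=model.parameters())
    load_path, client_state = fresh.load_checkpoint(save_folder, tag)

    # Engine-owned keys are restored internally; user keys come back as
    # client_state. The fixed 'global_steps' counter must survive the trip.
    assert client_state['my_app_step'] == 7
    assert fresh.global_steps == engine.global_steps
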
@@ -97,6 +127,22 @@ def test_checkpoint_unfused_optimizer(tmpdir):
         },
         "fp16": {
             "enabled": True
+        },
+        "scheduler": {
+            "type": "OneCycle",
+            "params": {
+                "cycle_first_step_size": 1000,
+                "cycle_first_stair_count": 500,
+                "cycle_second_step_size": 1000,
+                "cycle_second_stair_count": 500,
+                "decay_step_size": 1000,
+                "cycle_min_lr": 0.0001,
+                "cycle_max_lr": 0.0010,
+                "decay_lr_rate": 0.001,
+                "cycle_min_mom": 0.85,
+                "cycle_max_mom": 0.99,
+                "decay_mom_rate": 0.0
+            }
         }
     }
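
The scheduler block is new: configuring an lr scheduler means compare_lr_scheduler_states exercises a real state_dict instead of short-circuiting on None. For intuition, OneCycle ramps the learning rate from cycle_min_lr to cycle_max_lr over the first phase, back down over the second, then decays. A rough approximation of that curve, for intuition only; this is not DeepSpeed's exact OneCycle math, and stair counts and momentum cycling are ignored:

# Rough shape of the schedule configured above (approximation, not
# DeepSpeed's implementation). Defaults mirror the config's params.

def approx_one_cycle_lr(step,
                        first=1000, second=1000, decay=1000,
                        lo=0.0001, hi=0.0010, decay_rate=0.001):
    if step < first:                       # phase 1: ramp lo -> hi
        return lo + (hi - lo) * step / first
    if step < first + second:              # phase 2: ramp hi -> lo
        return hi - (hi - lo) * (step - first) / second
    # post-cycle: decay away from lo
    return lo / (1.0 + decay_rate * (step - first - second) / decay)

for s in (0, 500, 1000, 1500, 2000, 5000):
    print(s, round(approx_one_cycle_lr(s), 6))
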
@@ -110,7 +156,8 @@ def test_checkpoint_unfused_optimizer(tmpdir):
                                            model,
                                            hidden_dim,
                                            load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
@@ -151,7 +198,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
@@ -192,7 +240,8 @@ def test_checkpoint_zero_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
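
Finally, every test entry point now threads pytest's tmpdir fixture through to checkpoint_correctness_verification instead of writing to a hard-coded 'saved_checkpoint' directory, so each test run gets an isolated scratch folder. A minimal, hypothetical illustration of the fixture mechanics:

# Hypothetical test, not from the diff: pytest passes each test that declares
# `tmpdir` a unique py.path.local temporary directory, so repeated or parallel
# runs never collide on shared checkpoint paths.

def test_checkpoint_uses_scratch_dir(tmpdir):
    ckpt_dir = str(tmpdir)          # unique per test invocation
    # engine.save_checkpoint(ckpt_dir, '1')   # placeholders for engine calls
    # engine.load_checkpoint(ckpt_dir, '1')
    assert tmpdir.check(dir=True)   # the scratch directory exists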