Unverified Commit f2ac7eaf authored by Jeff Rasley, committed by GitHub

ZeRO-2 (#217)



Updates for ZeRO stage 2 and ZeRO stage 1 with reduce-scatter (RS).
Co-authored-by: Tunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: Elton Zheng <eltonz@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: yuxionghe <yuxhe@microsoft.com>
Co-authored-by: Arash Ashari <arashari@microsoft.com>
parent c61e23b4
@@ -8,7 +8,7 @@ from torch.multiprocessing import Process
 import pytest

 # Worker timeout *after* the first worker has completed.
-DEEPSPEED_UNIT_WORKER_TIMEOUT = 10
+DEEPSPEED_UNIT_WORKER_TIMEOUT = 120

 def distributed_test(world_size=2, backend='nccl'):
...
 import torch
 import deepspeed
 from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
+from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
 from deepspeed.pt.fp16_optimizer import FP16_Optimizer
 from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
@@ -9,6 +10,7 @@ import argparse
 import pytest
 import json
 import os
+import numbers
 from common import distributed_test
 from simple_model import SimpleModel, random_dataloader, args_from_dict
@@ -22,21 +24,6 @@ def compare_deepspeed_states(saved_model, loaded_model):
     assert saved_model.global_steps == loaded_model.global_steps

-def compare_lr_scheduler_states(saved_model, loaded_model):
-    if saved_model.lr_scheduler is None:
-        assert loaded_model.lr_scheduler is None
-        return
-
-    saved = saved_model.lr_scheduler.state_dict()
-    loaded = loaded_model.lr_scheduler.state_dict()
-
-    assert sorted(saved.keys()) == sorted(loaded.keys())
-    for key in saved.keys():
-        if isinstance(saved[key], torch.Tensor):
-            assert torch.equal(saved[key], loaded[key])
-        else:
-            assert saved[key] == loaded[key]
-
 def compare_model_states(saved_model, loaded_model):
     compare_deepspeed_states(saved_model, loaded_model)
@@ -47,6 +34,11 @@ def compare_model_states(saved_model, loaded_model):
         for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
             assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
+    elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
+        for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
+            for p0, p1 in zip(partition0, partition1):
+                assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
+
     elif isinstance(saved_model.optimizer, FP16_Optimizer):
         for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
             assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
@@ -61,10 +53,6 @@ def compare_model_states(saved_model, loaded_model):

 def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
-    compare_model_states(saved_model, loaded_model)
-
-    assert hasattr(loaded_model, 'optimizer')
-
     for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
                               loaded_model.optimizer.optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
@@ -74,11 +62,35 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
             assert s0 == s1

-def checkpoint_correctness_verification(save_folder,
-                                        args,
+def compare_lr_scheduler_states(saved_model, loaded_model):
+    assert hasattr(saved_model, 'lr_scheduler')
+    assert hasattr(loaded_model, 'lr_scheduler')
+
+    saved_scheduler = saved_model.lr_scheduler
+    loaded_scheduler = loaded_model.lr_scheduler
+
+    assert hasattr(saved_scheduler, 'state_dict')
+    assert hasattr(loaded_scheduler, 'state_dict')
+
+    saved_sd = saved_scheduler.state_dict()
+    loaded_sd = loaded_scheduler.state_dict()
+
+    print(f"saved_sd = {saved_sd}")
+    print(f"loaded_sd = {loaded_sd}")
+
+    assert saved_sd.keys() == loaded_sd.keys()
+
+    for state0, state1 in zip(saved_sd.values(), loaded_sd.values()):
+        if isinstance(state0, numbers.Number) and isinstance(state1, numbers.Number):
+            assert state0 == state1
+
+
+def checkpoint_correctness_verification(args,
                                         model,
                                         hidden_dim,
-                                        load_optimizer_states=True):
+                                        tmpdir,
+                                        load_optimizer_states=False,
+                                        load_lr_scheduler_states=False):

     ds_model, _, _,_ = deepspeed.initialize(args=args,
                                             model=model,
@@ -94,6 +106,7 @@ def checkpoint_correctness_verification(save_folder,
     trained_model = ds_model

+    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
     save_tag = '1'

     trained_model.save_checkpoint(save_folder, save_tag)
@@ -104,14 +117,16 @@ def checkpoint_correctness_verification(save_folder,
     loaded_model.load_checkpoint(save_folder,
                                  save_tag,
-                                 load_optimizer_states=load_optimizer_states)
+                                 load_optimizer_states=load_optimizer_states,
+                                 load_lr_scheduler_states=load_lr_scheduler_states)

-    compare_lr_scheduler_states(trained_model, loaded_model)
+    compare_model_states(trained_model, loaded_model)

     if load_optimizer_states:
         compare_optimizer_states(trained_model, loaded_model, hidden_dim)
-    else:
-        compare_model_states(trained_model, loaded_model)
+
+    if load_lr_scheduler_states:
+        compare_lr_scheduler_states(trained_model, loaded_model)


 def test_checkpoint_unfused_optimizer(tmpdir):
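
Note: the hunks above make the helper drive the full save/load round trip, restoring optimizer and LR-scheduler state only when the corresponding flags are set. A minimal sketch of the same pattern outside the test harness follows; the function name, checkpoint directory, and tag are illustrative, not from this PR.

    import deepspeed

    def save_and_reload(args, model, ckpt_dir, tag='1'):
        # Build the engine; args.deepspeed_config points at the JSON config file.
        engine, _, _, _ = deepspeed.initialize(args=args,
                                               model=model,
                                               model_parameters=model.parameters())

        # Persist model, optimizer, and LR-scheduler state under ckpt_dir/tag.
        engine.save_checkpoint(ckpt_dir, tag)

        # Restore everything, including the newly supported LR-scheduler state.
        engine.load_checkpoint(ckpt_dir,
                               tag,
                               load_optimizer_states=True,
                               load_lr_scheduler_states=True)
        return engine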
@@ -156,10 +171,10 @@ def test_checkpoint_unfused_optimizer(tmpdir):
                                            model,
                                            hidden_dim,
                                            load_optimizer_states):
-        checkpoint_correctness_verification(tmpdir,
-                                            args,
+        checkpoint_correctness_verification(args,
                                             model,
                                             hidden_dim,
+                                            tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_unfused_optimizer(args=args,
@@ -198,10 +213,10 @@ def test_checkpoint_fused_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(tmpdir,
-                                            args,
+        checkpoint_correctness_verification(args,
                                             model,
                                             hidden_dim,
+                                            tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_fused_optimizer(args=args,
@@ -214,7 +229,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
                                      load_optimizer_states=False)

-def test_checkpoint_zero_optimizer(tmpdir):
+@pytest.mark.parametrize("zero_stage", [1, 2])
+def test_checkpoint_zero_optimizer(tmpdir, zero_stage):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
@@ -231,7 +247,9 @@ def test_checkpoint_zero_optimizer(tmpdir):
         "fp16": {
             "enabled": True
         },
-        "zero_optimization": True
+        "zero_optimization": {
+            "stage": zero_stage
+        },
     }
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
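
Note on the hunk above: `zero_optimization` changes from a boolean into a nested section whose `stage` field selects the ZeRO stage (stage 1 partitions optimizer states, stage 2 additionally partitions gradients); the old boolean form is kept only as a deprecated format (see `test_zero_static_scale_deprecated_format` further down). An illustrative config with placeholder values showing the new shape:

    # Illustrative DeepSpeed config using the new nested "zero_optimization" section.
    # Values are placeholders; only the structure mirrors the tests in this PR.
    config_dict = {
        "train_batch_size": 2,
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 2   # 0 disables ZeRO, 1 partitions optimizer states, 2 also partitions gradients
        }
    }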
@@ -240,17 +258,165 @@ def test_checkpoint_zero_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(tmpdir,
-                                            args,
+        checkpoint_correctness_verification(args,
                                             model,
                                             hidden_dim,
+                                            tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_zero_optimizer(args=args,
                                     model=model,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)

-    _test_checkpoint_zero_optimizer(args=args,
-                                    model=model,
-                                    hidden_dim=hidden_dim,
-                                    load_optimizer_states=False)
+
+@pytest.mark.parametrize("zero_stage", [1, 2])
+def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015,
+                "betas": [0.8,
+                          0.999],
+                "eps": 1e-8,
+                "weight_decay": 3e-7
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_zero_no_optimizer(args,
+                                           model,
+                                           hidden_dim,
+                                           load_optimizer_states):
+        checkpoint_correctness_verification(args,
+                                            model,
+                                            hidden_dim,
+                                            tmpdir,
+                                            load_optimizer_states=load_optimizer_states)
+
+    _test_checkpoint_zero_no_optimizer(args=args,
+                                       model=model,
+                                       hidden_dim=hidden_dim,
+                                       load_optimizer_states=False)
+
+
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
+def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015,
+                "betas": [0.8,
+                          0.999],
+                "eps": 1e-8,
+                "weight_decay": 3e-7
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+        "scheduler": {
+            "type": "WarmupLR",
+            "params": {
+                "warmup_min_lr": 0,
+                "warmup_max_lr": 0.001,
+                "warmup_num_steps": 1000
+            }
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_lr_scheduler(args,
+                                      model,
+                                      hidden_dim,
+                                      load_optimizer_states,
+                                      load_lr_scheduler_states):
+        checkpoint_correctness_verification(
+            args,
+            model,
+            hidden_dim,
+            tmpdir,
+            load_optimizer_states=load_optimizer_states,
+            load_lr_scheduler_states=load_lr_scheduler_states)
+
+    _test_checkpoint_lr_scheduler(args=args,
+                                  model=model,
+                                  hidden_dim=hidden_dim,
+                                  load_optimizer_states=False,
+                                  load_lr_scheduler_states=True)
+
+
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
+def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 1e-5
+            }
+        },
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+        "scheduler": {
+            "type": "WarmupLR",
+            "params": {
+                "warmup_min_lr": 0,
+                "warmup_max_lr": 0.001,
+                "warmup_num_steps": 1000
+            }
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_no_lr_scheduler(args,
+                                         model,
+                                         hidden_dim,
+                                         load_optimizer_states,
+                                         load_lr_scheduler_states):
+        checkpoint_correctness_verification(
+            args,
+            model,
+            hidden_dim,
+            tmpdir,
+            load_optimizer_states=load_optimizer_states,
+            load_lr_scheduler_states=load_lr_scheduler_states)
+
+    _test_checkpoint_no_lr_scheduler(args=args,
+                                     model=model,
+                                     hidden_dim=hidden_dim,
+                                     load_optimizer_states=False,
+                                     load_lr_scheduler_states=False)
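
The new LR-scheduler checkpoint tests above rest on the usual `state_dict()` / `load_state_dict()` contract for schedulers. A small standalone sketch of that round trip, using a stock `torch.optim.lr_scheduler.StepLR` purely for illustration (DeepSpeed's `WarmupLR` is assumed to follow the same interface):

    # Standalone sketch of the scheduler state round trip these tests depend on.
    # StepLR is illustrative only; WarmupLR is assumed to expose the same
    # state_dict()/load_state_dict() interface.
    import torch

    param = torch.nn.Parameter(torch.zeros(10))
    opt = torch.optim.SGD([param], lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5)

    for _ in range(3):
        opt.step()
        sched.step()

    saved_sd = sched.state_dict()   # plain numbers/lists: step counts, base lrs, ...

    fresh_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5)
    fresh_sched.load_state_dict(saved_sd)

    # Same comparison idea as compare_lr_scheduler_states above.
    assert fresh_sched.state_dict().keys() == saved_sd.keys()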
@@ -17,3 +17,19 @@ def test_only_required_fields(tmpdir):
     assert run_cfg.train_batch_size == 64
     assert run_cfg.train_micro_batch_size_per_gpu == 64
     assert run_cfg.gradient_accumulation_steps == 1
+
+
+def test_config_duplicate_key(tmpdir):
+    config_dict = '''
+    {
+        "train_batch_size": 24,
+        "train_batch_size": 24,
+    }
+    '''
+    config_path = os.path.join(tmpdir, 'temp_config.json')
+    with open(config_path, 'w') as jf:
+        jf.write("%s" % config_dict)
+
+    with pytest.raises(ValueError):
+        run_cfg = ds_config.DeepSpeedConfig(config_path)
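
Python's `json.load` silently keeps the last value when a key is repeated, so catching duplicates requires an `object_pairs_hook`. A hedged sketch of one way such a check could be implemented; this is not necessarily how `DeepSpeedConfig` does it:

    # One possible duplicate-key check using json's object_pairs_hook.
    # Illustrative only; not DeepSpeedConfig's actual implementation.
    import json

    def dict_raise_on_duplicates(ordered_pairs):
        """Reject JSON objects that define the same key more than once."""
        d = {}
        for key, value in ordered_pairs:
            if key in d:
                raise ValueError(f"Duplicate JSON key found: {key}")
            d[key] = value
        return d

    def load_config(path):
        with open(path) as f:
            return json.load(f, object_pairs_hook=dict_raise_on_duplicates)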
@@ -144,7 +144,8 @@ def test_adamw_fp16_empty_grad(tmpdir):
     _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)

-def test_adam_fp16_onecycle_compatibility(tmpdir):
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
+def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage):
     config_dict = {
         "train_batch_size": 1,
         "steps_per_print": 1,
@@ -171,15 +172,18 @@ def test_adam_fp16_onecycle_compatibility(tmpdir):
         "fp16": {
             "enabled": True
         },
-        "zero_optimization": False
+        "zero_optimization": {
+            "stage": zero_stage
+        }
     }
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

     model = SimpleModel(hidden_dim, empty_grad=True)

     @distributed_test(world_size=[1])
-    def _test_adam_fp16_onecycle_compatibility(args, model, hidden_dim):
+    def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
@@ -192,12 +196,15 @@ def test_adam_fp16_onecycle_compatibility(tmpdir):
             model.backward(loss)
             model.step()

-    _test_adam_fp16_onecycle_compatibility(args=args, model=model, hidden_dim=hidden_dim)
+    _test_adam_fp16_zero_onecycle_compatibility(args=args,
+                                                model=model,
+                                                hidden_dim=hidden_dim)

-def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
+@pytest.mark.parametrize("zero_stage", [1, 2])
+def test_zero_static_scale(tmpdir, zero_stage):
     config_dict = {
-        "train_batch_size": 1,
+        "train_batch_size": 4,
         "steps_per_print": 1,
         "optimizer": {
             "type": "Adam",
@@ -205,37 +212,31 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
                 "lr": 0.00015
             }
         },
-        "scheduler": {
-            "type": "OneCycle",
-            "params": {
-                "cycle_first_step_size": 16000,
-                "cycle_first_stair_count": 8000,
-                "decay_step_size": 16000,
-                "cycle_min_lr": 1e-06,
-                "cycle_max_lr": 3e-05,
-                "decay_lr_rate": 1e-07,
-                "cycle_min_mom": 0.85,
-                "cycle_max_mom": 0.99,
-                "decay_mom_rate": 0.0
-            }
-        },
         "fp16": {
-            "enabled": True
+            "enabled": True,
+            "loss_scale": 138.
         },
-        "zero_optimization": True
+        "zero_optimization": {
+            "stage": zero_stage
+        }
     }
     args = args_from_dict(tmpdir, config_dict)
-    hidden_dim = 10
-    model = SimpleModel(hidden_dim, empty_grad=True)
-
-    @distributed_test(world_size=[1])
-    def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
-                                             model=model,
-                                             model_parameters=model.parameters())
+
+    @distributed_test(world_size=2)
+    def _test_zero_static_scale(args):
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim, empty_grad=True)
+        model, optim, _,_ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 model_parameters=model.parameters())
+
+        # Ensure the static scaler is configured.
+        assert optim.dynamic_loss_scale == False
+        assert optim.loss_scaler.loss_scale == 138.
+
+        # Now make sure things work..
         data_loader = random_dataloader(model=model,
-                                        total_samples=50,
+                                        total_samples=10,
                                         hidden_dim=hidden_dim,
                                         device=model.device)
         for n, batch in enumerate(data_loader):
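
For reference on the hunk above: a positive `"loss_scale"` in the `fp16` block pins a static loss scale (dynamic scaling is selected when the value is 0 or the field is omitted), which is what the `dynamic_loss_scale == False` assertion checks. A minimal sketch contrasting the two modes, with illustrative values:

    # Illustrative fp16 blocks: static loss scale vs. dynamic scaling.
    static_fp16 = {
        "fp16": {
            "enabled": True,
            "loss_scale": 138.   # any positive value pins a static scale
        }
    }

    dynamic_fp16 = {
        "fp16": {
            "enabled": True,
            "loss_scale": 0      # 0 (or omitting the field) selects dynamic loss scaling
        }
    }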
@@ -243,12 +244,10 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
             model.backward(loss)
             model.step()

-    _test_adam_fp16_zero_onecycle_compatibility(args=args,
-                                                model=model,
-                                                hidden_dim=hidden_dim)
+    _test_zero_static_scale(args)

-def test_zero_static_scale(tmpdir):
+def test_zero_static_scale_deprecated_format(tmpdir):
     config_dict = {
         "train_batch_size": 4,
         "steps_per_print": 1,
@@ -291,14 +290,17 @@ def test_zero_static_scale(tmpdir):
     _test_zero_static_scale(args)

-def test_zero_allow_untested_optimizer(tmpdir):
+@pytest.mark.parametrize("zero_stage", [1, 2])
+def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
     config_dict = {
         "train_batch_size": 4,
         "steps_per_print": 1,
         "fp16": {
             "enabled": True,
         },
-        "zero_optimization": True,
+        "zero_optimization": {
+            "stage": zero_stage
+        },
         "zero_allow_untested_optimizer": False
     }
     args = args_from_dict(tmpdir, config_dict)
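
For context, `zero_allow_untested_optimizer` controls whether ZeRO will wrap a client-supplied optimizer it has not been validated against; with the flag set to False, as above, initialization is expected to reject such an optimizer. An illustrative config that opts in instead (values are placeholders):

    # Illustrative: allowing a client-side optimizer (e.g. torch.optim.Adam) under ZeRO.
    # Values are placeholders; only the flag's placement mirrors the test above.
    config_dict = {
        "train_batch_size": 4,
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 1
        },
        "zero_allow_untested_optimizer": True
    }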
@@ -317,31 +319,34 @@ def test_zero_allow_untested_optimizer(tmpdir):
     _test_zero_allow_untested_optimizer(args)

-def test_zero_empty_partition(tmpdir):
-    config_dict = {
-        "train_batch_size": 3,
-        "fp16": {
-            "enabled": True
-        },
-        "optimizer": {
-            "type": "Adam",
-            "params": {
-                "lr": 0.00015
-            }
-        },
-        "zero_optimization": True
-    }
-    args = args_from_dict(tmpdir, config_dict)
-
-    @distributed_test(world_size=[3])
-    def _test_zero_empty_partition(args):
-        hidden_dim = 1
-        model = SimpleModel(hidden_dim)
-        # Ensure model has 2 parameters, to cause empty partition with DP=3
-        assert len(list(model.parameters())) == 2
-        model, _, _, _ = deepspeed.initialize(args=args,
-                                              model=model,
-                                              model_parameters=model.parameters())
-        model.step()
-
-    _test_zero_empty_partition(args)
+# @pytest.mark.parametrize("zero_stage", [1])
+# def test_zero_empty_partition(tmpdir, zero_stage):
+#     config_dict = {
+#         "train_batch_size": 3,
+#         "fp16": {
+#             "enabled": True
+#         },
+#         "optimizer": {
+#             "type": "Adam",
+#             "params": {
+#                 "lr": 0.00015
+#             }
+#         },
+#         "zero_optimization": {
+#             "stage": zero_stage
+#         }
+#     }
+#     args = args_from_dict(tmpdir, config_dict)

+#     @distributed_test(world_size=[3])
+#     def _test_zero_empty_partition(args):
+#         hidden_dim = 1
+#         model = SimpleModel(hidden_dim)
+#         # Ensure model has 2 parameters, to cause empty partition with DP=3
+#         assert len(list(model.parameters())) == 2
+#         model, _, _, _ = deepspeed.initialize(args=args,
+#                                               model=model,
+#                                               model_parameters=model.parameters())
+#         model.step()

+#     _test_zero_empty_partition(args)
@@ -73,7 +73,7 @@ def test_two_output_model(tmpdir):
         summed_loss = sum(loss_tuple)
         scaled_loss = model.backward(summed_loss)
-        expected_scaled_loss = summed_loss / gradient_accumulation_steps
+        expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
         assert scaled_loss.item() == approx(expected_scaled_loss.item())

         model.step()
@@ -131,7 +131,7 @@ def test_three_output_model(tmpdir):
         summed_loss = sum(loss_tuple)
         scaled_loss = model.backward(summed_loss)
-        expected_scaled_loss = summed_loss / gradient_accumulation_steps
+        expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
         assert scaled_loss.item() == approx(expected_scaled_loss.item())

         model.step()
...
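
Note on the `.float()` change in the two hunks above: under the fp16 engine `summed_loss` is a half-precision tensor, so dividing it by `gradient_accumulation_steps` in fp16 can round differently from the engine's own fp32 scaling; casting to fp32 first keeps the expected value and `scaled_loss` comparable. A small illustration of the rounding gap, with arbitrary values:

    # Arbitrary values illustrating fp16 vs fp32 rounding when dividing a loss.
    import torch

    loss_fp16 = torch.tensor(10.671875, dtype=torch.half)
    steps = 3

    print((loss_fp16 / steps).item())          # divided in half precision
    print((loss_fp16.float() / steps).item())  # divided in single precision
    # The two results differ in the low-order digits, which is enough to trip
    # an approx() comparison against the engine's fp32 arithmetic.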