Unverified commit 6021b702, authored by Olatunji Ruwase, committed by GitHub

Support non-tensor state in checkpoint (#548)

parent 0178e6cc
@@ -947,9 +947,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                                         state_key,
                                         all_partition_states,
                                         max_elems_per_comm):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        alignment = dist.get_world_size(group=self.dp_process_group)
+        if not torch.is_tensor(all_partition_states[0]):
+            return all_partition_states[0]
+
+        alignment = dist.get_world_size(group=self.dp_process_group)
         flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_partition_states,
             dp=dist.get_world_size(group=self.dp_process_group),
@@ -964,6 +965,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             dp_process_group=self.dp_process_group
         )

+        partition_id = dist.get_rank(group=self.dp_process_group)
         return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]

     # Compute the optimizer state partitions for the group by
@@ -1013,8 +1015,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         for group_idx, group in enumerate(self.optimizer.param_groups):
             for param_idx, param in enumerate(group['params']):
                 for key, saved in base_optimizer_group_states[group_idx].items():
-                    current = self.optimizer.state[param][key]
-                    current.data.copy_(saved[param_idx].data)
+                    if torch.is_tensor(self.optimizer.state[param][key]):
+                        current = self.optimizer.state[param][key]
+                        current.data.copy_(saved[param_idx].data)
+                    else:
+                        self.optimizer.state[param][key] = saved

     # Restore base optimizer fp32 weights from ZeRO fp16 weights
     def _restore_from_fp16_weights(self):
...
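The two guards above share one idea: tensor-valued optimizer state is flattened and partitioned across data-parallel ranks, while non-tensor state (for example a plain Python step counter) is identical on every rank and is carried through unchanged. A minimal standalone sketch of that idea; partition_state and restore_state are hypothetical helpers written for illustration, not DeepSpeed APIs:

import torch

def partition_state(all_partition_states):
    # Non-tensor state is replicated on every rank; keep the first copy as-is.
    if not torch.is_tensor(all_partition_states[0]):
        return all_partition_states[0]
    # Tensor state gets flattened and split across ranks (a stand-in for
    # flatten_dense_tensors_sub_partition_aligned in the diff above).
    flat = torch.cat([t.flatten() for t in all_partition_states])
    return list(torch.chunk(flat, len(all_partition_states)))

def restore_state(state, key, saved):
    # Mirror of the load path above: copy tensors in place, assign everything else.
    if torch.is_tensor(state[key]):
        state[key].data.copy_(saved.data)
    else:
        state[key] = saved

state = {'step': 0, 'exp_avg': torch.zeros(4)}
restore_state(state, 'step', 7)                          # plain int: reassigned
restore_state(state, 'exp_avg', torch.ones(4))           # tensor: copied in place
print(partition_state([5, 5]))                           # -> 5
print(partition_state([torch.ones(2), torch.ones(2)]))   # -> two tensor chunks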
@@ -101,6 +101,37 @@ class SimpleOptimizer(torch.optim.Optimizer):
         return loss


+class HybridStateOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, lr=0.11072018):
+        defaults = dict(lr=lr)
+        super(HybridStateOptimizer, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(HybridStateOptimizer, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['integer_step'] = 0
+                    state['tensor_step'] = torch.zeros(1)
+
+                d_p = p.grad.data
+                p.data.add_(-group['lr'], d_p)
+                state['integer_step'] += 1
+                state['tensor_step'] += 1
+
+        return loss
+
+
 class PLD_SimpleModel(SimpleModel):
     def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
...
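HybridStateOptimizer is a test-only optimizer whose per-parameter state deliberately mixes a plain Python int with a tensor, which is exactly the case the checkpoint path now has to survive. A small usage sketch, assuming the class above is in scope (the import path is an assumption, and the code targets the PyTorch API of the time, where add_(alpha, tensor) was still accepted):

import torch
from simple_model import HybridStateOptimizer  # assumed module path of the test helper above

p = torch.nn.Parameter(torch.randn(3))
opt = HybridStateOptimizer([p])

p.grad = torch.ones_like(p)   # fake a gradient so step() populates and updates the state
opt.step()

state = opt.state[p]
print(type(state['integer_step']), state['integer_step'])  # <class 'int'> 1
print(type(state['tensor_step']), state['tensor_step'])    # <class 'torch.Tensor'> tensor([1.])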
@@ -36,6 +36,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     compare_deepspeed_states(saved_model, loaded_model)

     for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
+        assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
         assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"

     if not compare_optimizer:
@@ -43,20 +44,24 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
         for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
         for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
             for p0, p1 in zip(partition0, partition1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_Optimizer):
         for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
         pass
@@ -72,6 +77,7 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
                           loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
             if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
+                assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
                 assert torch.equal(s0, s1)
             else:
                 assert s0 == s1
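The new id() assertions guard against a silent aliasing bug: if the saved and loaded models ever shared the same parameter tensors, torch.allclose would trivially pass and the round-trip check would prove nothing. A tiny illustration of the pitfall, using hypothetical names:

import torch

saved = torch.randn(4)
aliased = saved                       # same object: allclose passes trivially
independent = saved.clone()          # distinct object holding equal values

assert torch.allclose(saved, aliased) and torch.allclose(saved, independent)
assert id(saved) != id(independent)  # the new check: never compare a tensor to itself
print(id(saved) == id(aliased))      # True -> this is the case the assertion rejects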
@@ -100,18 +106,34 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
         assert state0 == state1


+def create_deepspeed_model(args, model, base_optimizer):
+    if base_optimizer is None:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 model_parameters=model.parameters())
+    else:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 optimizer=base_optimizer)
+
+    return ds_model
+
+
 def checkpoint_correctness_verification(args,
-                                        model,
+                                        models,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
                                         fp16=True,
-                                        train_batch=False):
+                                        train_batch=False,
+                                        base_optimizers=[None,
+                                                         None]):
     dtype = torch.half if fp16 else torch.float32
-    ds_model, _, _, _ = deepspeed.initialize(args=args,
-                                             model=model,
-                                             model_parameters=model.parameters())
+    ds_model = create_deepspeed_model(args=args,
+                                      model=models[0],
+                                      base_optimizer=base_optimizers[0])
+
     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
@@ -125,7 +147,6 @@ def checkpoint_correctness_verification(args,
     else:
         for n, batch in enumerate(data_loader):
             loss = ds_model(batch[0], batch[1])
-            print(loss)
             ds_model.backward(loss)
             ds_model.step()
@@ -136,9 +157,9 @@ def checkpoint_correctness_verification(args,
     trained_model.save_checkpoint(save_folder, save_tag)

-    loaded_model, _, _, _ = deepspeed.initialize(args=args,
-                                                 model=model,
-                                                 model_parameters=model.parameters())
+    loaded_model = create_deepspeed_model(args=args,
+                                          model=models[1],
+                                          base_optimizer=base_optimizers[1])

     loaded_model.load_checkpoint(save_folder,
                                  save_tag,
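With these changes, checkpoint_correctness_verification trains one model instance, saves a checkpoint, and loads it into a second, independently constructed instance, optionally wiring a client optimizer into each via create_deepspeed_model. A hedged sketch of how a test now calls it, mirroring the new hybrid-state test at the bottom of this diff (args, tmpdir, and hidden_dim come from the test fixtures):

# Two separate model/optimizer pairs, so the saved and loaded states are
# genuinely different objects (see the id(p0) != id(p1) asserts above).
models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
optimizers = [HybridStateOptimizer(m.parameters()) for m in models]

checkpoint_correctness_verification(args,
                                    models=models,
                                    base_optimizers=optimizers,
                                    hidden_dim=hidden_dim,
                                    tmpdir=tmpdir,
                                    load_optimizer_states=True)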
@@ -191,25 +212,26 @@ def test_checkpoint_unfused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
     def _test_checkpoint_unfused_optimizer(args,
-                                           model,
+                                           models,
                                            hidden_dim,
                                            load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=True)

     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
@@ -236,22 +258,26 @@ def test_checkpoint_fused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_fused_optimizer(args,
+                                         models,
+                                         hidden_dim,
+                                         load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=True)

     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False)
@@ -293,18 +319,18 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_zero_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_zero_optimizer(args=args,
-                                    model=model,
+                                    models=models,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
@@ -346,21 +372,21 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_no_optimizer(args,
-                                           model,
+                                           models,
                                            hidden_dim,
                                            load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_zero_no_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
@@ -412,24 +438,24 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
     def _test_checkpoint_lr_scheduler(args,
-                                      model,
+                                      models,
                                       hidden_dim,
                                       load_optimizer_states,
                                       load_lr_scheduler_states):
         checkpoint_correctness_verification(
             args,
-            model,
-            hidden_dim,
-            tmpdir,
+            models=models,
+            hidden_dim=hidden_dim,
+            tmpdir=tmpdir,
             load_optimizer_states=load_optimizer_states,
             load_lr_scheduler_states=load_lr_scheduler_states)

     _test_checkpoint_lr_scheduler(args=args,
-                                  model=model,
+                                  models=models,
                                   hidden_dim=hidden_dim,
                                   load_optimizer_states=False,
                                   load_lr_scheduler_states=True)
@@ -478,24 +504,24 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
     def _test_checkpoint_no_lr_scheduler(args,
-                                         model,
+                                         models,
                                          hidden_dim,
                                          load_optimizer_states,
                                          load_lr_scheduler_states):
         checkpoint_correctness_verification(
             args,
-            model,
-            hidden_dim,
-            tmpdir,
+            models=models,
+            hidden_dim=hidden_dim,
+            tmpdir=tmpdir,
             load_optimizer_states=load_optimizer_states,
             load_lr_scheduler_states=load_lr_scheduler_states)

     _test_checkpoint_no_lr_scheduler(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False,
                                      load_lr_scheduler_states=False)
@@ -523,13 +549,17 @@ def test_checkpoint_fp32_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
-        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)
+    def _test_checkpoint_fp32_optimizer(args, models, hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            fp16=False)

-    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
+    _test_checkpoint_fp32_optimizer(args=args, models=models, hidden_dim=hidden_dim)


 @pytest.mark.parametrize("zero_stage", [0, 1])
@@ -571,10 +601,10 @@ def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
     @distributed_test(world_size=4)
     def _test(save_folder, num_stages):
         args = args_from_dict(tmpdir, config_dict)
-        model = LinearStackPipe(num_stages=num_stages)
+        models = [LinearStackPipe(num_stages=num_stages) for _ in range(2)]
         checkpoint_correctness_verification(args=args,
-                                            model=model,
-                                            hidden_dim=model.hidden_dim,
+                                            models=models,
+                                            hidden_dim=models[0].hidden_dim,
                                             tmpdir=save_folder,
                                             fp16=config_dict['fp16']['enabled'],
                                             load_optimizer_states=True,
@@ -635,3 +665,42 @@ def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
             assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"

     _test(base_topo, test_topo, save_folder=tmpdir)
+
+
+@pytest.mark.parametrize('zero_stage', [1, 2])
+def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 2,
+        "gradient_accumulation_steps": 2,
+        "steps_per_print": 1,
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+        "zero_allow_untested_optimizer": True,
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 8
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+    optimizers = [HybridStateOptimizer(model.parameters()) for model in models]
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_zero_hybrid_optimizer_state(args,
+                                                     models,
+                                                     optimizers,
+                                                     hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            base_optimizers=optimizers,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True)
+
+    _test_checkpoint_zero_hybrid_optimizer_state(args=args,
+                                                 models=models,
+                                                 optimizers=optimizers,
+                                                 hidden_dim=hidden_dim)
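What the new test checks, stripped to its essence: after saving and reloading, every optimizer state entry must match, whether it is a tensor or a plain Python value. A minimal, framework-free sketch of that comparison using only torch serialization (independent of DeepSpeed, and mirroring the tensor/non-tensor branch in compare_optimizer_states above):

import os
import tempfile
import torch

saved_state = {'integer_step': 3, 'tensor_step': torch.tensor([3.0])}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'state.pt')
    torch.save(saved_state, path)
    loaded_state = torch.load(path)

for s0, s1 in zip(saved_state.values(), loaded_state.values()):
    if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
        assert torch.equal(s0, s1)   # tensor entries compare element-wise
    else:
        assert s0 == s1              # non-tensor entries compare by value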