Unverified commit a9a83a6f, authored by gcooper-isi and committed by GitHub
Browse files

Allow DeepSpeed models to be initialized with optimizer=None (#469)



Allow DeepSpeed models to be initialized with optimizer=None
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
parent e6ac7311
......@@ -466,10 +466,9 @@ class DeepSpeedEngine(Module):
# Validate configuration based on command line arguments
def _do_sanity_check(self):
if not self.client_optimizer:
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
assert self.client_model_parameters, \
'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
if self.optimizer_name() == LAMB_OPTIMIZER:
assert self.dynamic_loss_scale(), \
......@@ -1289,7 +1288,7 @@ class DeepSpeedEngine(Module):
self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
if self.optimizer is not None and not self.zero_optimization():
if self.fp16_enabled():
self.optimizer.load_state_dict(
checkpoint['optimizer'],
......
......@@ -195,3 +195,34 @@ def test_dist_init_true(tmpdir):
model.step()
_test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)
def test_init_no_optimizer(tmpdir):
    """Check that an engine initialized without an optimizer rejects training steps.

    The engine is built via ``deepspeed.initialize`` with only ``args`` and
    ``model`` (no optimizer argument, and no optimizer section in the config).
    The forward pass is expected to run, while ``backward`` and ``step`` are
    each expected to raise ``AssertionError``.
    """
    # Minimal config: fp16 enabled, batch size 1, and deliberately no optimizer.
    config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}
    config_path = create_config_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _helper():
        # Build an empty namespace and attach the fields set by hand below.
        parser = argparse.ArgumentParser()
        args = parser.parse_args(args='')
        args.deepscale_config = config_path
        args.local_rank = 0

        hidden_dim = 10
        model = SimpleModel(hidden_dim=hidden_dim)

        # No optimizer is passed here; only the returned engine is kept.
        model, _, _, _ = deepspeed.initialize(args=args, model=model)
        data_loader = random_dataloader(model=model,
                                        total_samples=5,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        # The batch index was never used, so iterate the loader directly.
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            with pytest.raises(AssertionError):
                model.backward(loss)
            with pytest.raises(AssertionError):
                model.step()

    _helper()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment