"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "4375c2d7e42d463f0585a7271f83b6d2d313cd14"
Unverified commit a9a83a6f authored by gcooper-isi, committed by GitHub
Browse files

Allow DeepSpeed models to be initialized with optimizer=None (#469)



Allow DeepSpeed models to be initialized with optimizer=None
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
parent e6ac7311
@@ -466,10 +466,9 @@ class DeepSpeedEngine(Module):
     # Validate configuration based on command line arguments
     def _do_sanity_check(self):
         if not self.client_optimizer:
-            assert self._is_supported_optimizer(self.optimizer_name()), \
-                '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
-            assert self.client_model_parameters, \
-                'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())
+            if self.optimizer_name() is not None:
+                assert self._is_supported_optimizer(self.optimizer_name()), \
+                    '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())

         if self.optimizer_name() == LAMB_OPTIMIZER:
             assert self.dynamic_loss_scale(), \
@@ -1289,7 +1288,7 @@ class DeepSpeedEngine(Module):
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)
-        if not self.zero_optimization():
+        if self.optimizer is not None and not self.zero_optimization():
             if self.fp16_enabled():
                 self.optimizer.load_state_dict(
                     checkpoint['optimizer'],
...
@@ -195,3 +195,34 @@ def test_dist_init_true(tmpdir):
         model.step()

     _test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)
def test_init_no_optimizer(tmpdir):
    """Initialize a DeepSpeed engine with no client optimizer and no
    optimizer section in the config (fp16 enabled).

    The engine should construct successfully, but calling backward() or
    step() on it must raise AssertionError, since there is no optimizer
    to drive the fp16 training loop.
    """
    config_path = create_config_from_dict(
        tmpdir,
        {"train_batch_size": 1, "fp16": {"enabled": True}},
    )

    @distributed_test(world_size=1)
    def _helper():
        # Minimal argparse namespace pointing at the generated config.
        args = argparse.ArgumentParser().parse_args(args='')
        args.deepscale_config = config_path
        args.local_rank = 0

        hidden_dim = 10
        net = SimpleModel(hidden_dim=hidden_dim)
        # Note: no optimizer argument is passed to initialize().
        engine, _, _, _ = deepspeed.initialize(args=args, model=net)

        loader = random_dataloader(model=engine,
                                   total_samples=5,
                                   hidden_dim=hidden_dim,
                                   device=engine.device)
        for batch in loader:
            loss = engine(batch[0], batch[1])
            # Without an optimizer, both halves of the training step
            # are expected to fail with an assertion.
            with pytest.raises(AssertionError):
                engine.backward(loss)
            with pytest.raises(AssertionError):
                engine.step()

    _helper()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment