Unverified Commit af81f6f5 authored by Olatunji Ruwase's avatar Olatunji Ruwase Committed by GitHub
Browse files

Handle missing optional configuration fields correctly (#24)


Co-authored-by: default avatarShaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: default avatarJeff Rasley <jerasley@microsoft.com>
parent 54940278
...@@ -256,7 +256,8 @@ class DeepSpeedConfig(object): ...@@ -256,7 +256,8 @@ class DeepSpeedConfig(object):
self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
self.optimizer_name = get_optimizer_name(param_dict) self.optimizer_name = get_optimizer_name(param_dict)
if self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS: if self.optimizer_name is not None and \
self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
self.optimizer_name = self.optimizer_name.lower() self.optimizer_name = self.optimizer_name.lower()
self.optimizer_params = get_optimizer_params(param_dict) self.optimizer_params = get_optimizer_params(param_dict)
......
import pytest
import os
import json
from deepspeed.pt import deepspeed_config as ds_config
def test_only_required_fields(tmpdir):
'''Ensure that config containing only the required fields is accepted. '''
cfg_json = tmpdir.mkdir('ds_config_unit_test').join('minimal.json')
with open(cfg_json, 'w') as f:
required_fields = {'train_batch_size': 64}
json.dump(required_fields, f)
run_cfg = ds_config.DeepSpeedConfig(cfg_json)
assert run_cfg is not None
assert run_cfg.train_batch_size == 64
assert run_cfg.train_micro_batch_size_per_gpu == 64
assert run_cfg.gradient_accumulation_steps == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment