Unverified Commit 8326aff2 authored by Olatunji Ruwase's avatar Olatunji Ruwase Committed by GitHub
Browse files

Improve doc string for add_XXX_arguments (#32)

Unit tests for add_XXX_arguments
parent af81f6f5
...@@ -37,7 +37,7 @@ def initialize(args, ...@@ -37,7 +37,7 @@ def initialize(args,
optimizer: Optional: a user defined optimizer, this is typically used instead of defining optimizer: Optional: a user defined optimizer, this is typically used instead of defining
an optimizer in the DeepSpeed json config. an optimizer in the DeepSpeed json config.
model_parameters: Optional: An iterable of torch.Tensor s or dicts. model_parameters: Optional: An iterable of torch.Tensors or dicts.
Specifies what Tensors should be optimized. Specifies what Tensors should be optimized.
training_data: Optional: Dataset of type torch.utils.data.Dataset training_data: Optional: Dataset of type torch.utils.data.Dataset
...@@ -55,8 +55,20 @@ def initialize(args, ...@@ -55,8 +55,20 @@ def initialize(args,
map-style dataset. map-style dataset.
Return: Return:
The following tuple is returned by this function.
tuple: engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler tuple: engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler
engine: DeepSpeed runtime engine which wraps the client model for distributed training.
engine.optimizer: Wrapped optimizer if a user defined optimizer is passed or
if optimizer is specified in json config else None.
engine.training_dataloader: DeepSpeed dataloader if training data was passed else None.
engine.lr_scheduler: Wrapped lr scheduler if user lr scheduler is passed
or if lr scheduler specified in json config else None.
""" """
print("DeepSpeed info: version={}, git-hash={}, git-branch={}".format( print("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
__version__, __version__,
...@@ -83,8 +95,13 @@ def initialize(args, ...@@ -83,8 +95,13 @@ def initialize(args,
return tuple(return_items) return tuple(return_items)
def add_core_arguments(parser): def _add_core_arguments(parser):
r"""Adds argument group for enabling deepspeed and providing deepspeed config file r"""Helper (internal) function to update an argument parser with an argument group of the core DeepSpeed arguments.
The core set of DeepSpeed arguments include the following:
1) --deepspeed: boolean flag to enable DeepSpeed
2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.
This is a helper function to the public add_config_arguments()
Arguments: Arguments:
parser: argument parser parser: argument parser
...@@ -107,13 +124,16 @@ def add_core_arguments(parser): ...@@ -107,13 +124,16 @@ def add_core_arguments(parser):
def add_config_arguments(parser): def add_config_arguments(parser):
r"""Updates the parser to parse DeepSpeed arguments r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments.
The set of DeepSpeed arguments include the following:
1) --deepspeed: boolean flag to enable DeepSpeed
2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.
Arguments: Arguments:
parser: argument parser parser: argument parser
Return: Return:
parser: Updated Parser parser: Updated Parser
""" """
parser = add_core_arguments(parser) parser = _add_core_arguments(parser)
return parser return parser
import argparse
import pytest
import deepspeed
def basic_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', type=int)
return parser
def test_no_ds_arguments_no_ds_parser():
parser = basic_parser()
args = parser.parse_args(['--num_epochs', '2'])
assert args
assert hasattr(args, 'num_epochs')
assert args.num_epochs == 2
assert not hasattr(args, 'deepspeed')
assert not hasattr(args, 'deepspeed_config')
def test_no_ds_arguments():
parser = basic_parser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(['--num_epochs', '2'])
assert args
assert hasattr(args, 'num_epochs')
assert args.num_epochs == 2
assert hasattr(args, 'deepspeed')
assert args.deepspeed == False
assert hasattr(args, 'deepspeed_config')
assert args.deepspeed_config == None
def test_no_ds_enable_argument():
parser = basic_parser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(['--num_epochs', '2', '--deepspeed_config', 'foo.json'])
assert args
assert hasattr(args, 'num_epochs')
assert args.num_epochs == 2
assert hasattr(args, 'deepspeed')
assert args.deepspeed == False
assert hasattr(args, 'deepspeed_config')
assert type(args.deepspeed_config) == str
assert args.deepspeed_config == 'foo.json'
def test_no_ds_config_argument():
parser = basic_parser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(['--num_epochs', '2', '--deepspeed'])
assert args
assert hasattr(args, 'num_epochs')
assert args.num_epochs == 2
assert hasattr(args, 'deepspeed')
assert type(args.deepspeed) == bool
assert args.deepspeed == True
assert hasattr(args, 'deepspeed_config')
assert args.deepspeed_config == None
def test_no_ds_parser():
parser = basic_parser()
with pytest.raises(SystemExit):
args = parser.parse_args(['--num_epochs', '2', '--deepspeed'])
def test_core_deepscale_arguments():
parser = basic_parser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(
['--num_epochs',
'2',
'--deepspeed',
'--deepspeed_config',
'foo.json'])
assert args
assert hasattr(args, 'num_epochs')
assert args.num_epochs == 2
assert hasattr(args, 'deepspeed')
assert type(args.deepspeed) == bool
assert args.deepspeed == True
assert hasattr(args, 'deepspeed_config')
assert type(args.deepspeed_config) == str
assert args.deepspeed_config == 'foo.json'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment