"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d185c0dfa7904210e099dbbb45f02ad9c40d7904"
Unverified Commit e6ac7311 authored by Olatunji Ruwase's avatar Olatunji Ruwase Committed by GitHub
Browse files

Support initialization with dict configuration (#632)

parent 24e07399
...@@ -435,9 +435,9 @@ class DeepSpeedEngine(Module): ...@@ -435,9 +435,9 @@ class DeepSpeedEngine(Module):
# Configure based on command line arguments # Configure based on command line arguments
def _configure_with_arguments(self, args, mpu): def _configure_with_arguments(self, args, mpu):
self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0 self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
self._config = DeepSpeedConfig(args.deepspeed_config, config_file = args.deepspeed_config if hasattr(args,
mpu, 'deepspeed_config') else None
param_dict=self.config_params) self._config = DeepSpeedConfig(config_file, mpu, param_dict=self.config_params)
# Validate command line arguments # Validate command line arguments
def _do_args_sanity_check(self, args): def _do_args_sanity_check(self, args):
......
...@@ -161,12 +161,10 @@ def create_config_from_dict(tmpdir, config_dict): ...@@ -161,12 +161,10 @@ def create_config_from_dict(tmpdir, config_dict):
return config_path return config_path
def args_from_dict(tmpdir, config_dict): def create_deepspeed_args():
config_path = create_config_from_dict(tmpdir, config_dict)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
args = parser.parse_args(args='') args = parser.parse_args(args='')
args.deepspeed = True args.deepspeed = True
args.deepspeed_config = config_path
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
# We assume up to one full node executing unit tests # We assume up to one full node executing unit tests
assert torch.distributed.get_world_size() <= torch.cuda.device_count() assert torch.distributed.get_world_size() <= torch.cuda.device_count()
...@@ -174,3 +172,10 @@ def args_from_dict(tmpdir, config_dict): ...@@ -174,3 +172,10 @@ def args_from_dict(tmpdir, config_dict):
else: else:
args.local_rank = 0 args.local_rank = 0
return args return args
def args_from_dict(tmpdir, config_dict):
args = create_deepspeed_args()
config_path = create_config_from_dict(tmpdir, config_dict)
args.deepspeed_config = config_path
return args
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
import os import os
from deepspeed.ops.adam import FusedAdam from deepspeed.ops.adam import FusedAdam
from common import distributed_test from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args
try: try:
from apex import amp from apex import amp
...@@ -194,6 +194,41 @@ def test_adamw_fp16_basic(tmpdir): ...@@ -194,6 +194,41 @@ def test_adamw_fp16_basic(tmpdir):
_test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)
def test_dict_config_adamw_fp16_basic():
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"fp16": {
"enabled": True
}
}
args = create_deepspeed_args()
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict):
optimizer = torch.optim.AdamW(params=model.parameters())
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
optimizer=optimizer,
config_params=config_dict)
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
_test_adamw_fp16_basic(args=args,
model=model,
hidden_dim=hidden_dim,
config_dict=config_dict)
def test_adamw_fp16_empty_grad(tmpdir): def test_adamw_fp16_empty_grad(tmpdir):
config_dict = { config_dict = {
"train_batch_size": 1, "train_batch_size": 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment