Unverified Commit 80f434cb authored by guoshzhao, committed by GitHub

add more checks for PytorchBase module (#19)


Co-authored-by: Guoshuai Zhao <guzhao@microsoft.com>
parent 5d11579a
@@ -40,7 +40,7 @@ def _init_distributed_setting(self):
         if self._args.distributed_impl:
             logger.info(
                 'Distributed training is enabled - model: {}, distributed implementation: {}.'.format(
-                    self._name, self._args.distributed_impl.value
+                    self._name, self._args.distributed_impl
                 )
             )
             if self._args.distributed_impl == DistributedImpl.HOROVOD:
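For context on this change: logging the enum member itself instead of its .value keeps the message safe even if the attribute ends up holding something other than a DistributedImpl member. A minimal illustrative sketch, not part of the commit (the enum here is simplified):

from enum import Enum

class DistributedImpl(Enum):
    HOROVOD = 'horovod'
    DDP = 'ddp'

for impl in (DistributedImpl.DDP, 'mpi', None):
    # str-formatting works for enum members, raw strings, and None alike ...
    print('distributed implementation: {}'.format(impl))
    # ... while impl.value would raise AttributeError for 'mpi' and None.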
@@ -49,20 +49,20 @@ def _init_distributed_setting(self):
                 self._world_size = int(hvd.size())
                 self._local_rank = int(hvd.local_rank())
             elif self._args.distributed_impl == DistributedImpl.DDP:
-                torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
-                if os.environ.get('WORLD_SIZE') is False or os.environ.get('LOCAL_RANK') is False:
+                if os.environ.get('WORLD_SIZE') is None or os.environ.get('LOCAL_RANK') is None:
                     logger.error(
                         'Can not find WORLD_SIZE or LOCAL_RANK in env variables - model: {},'
-                        ' distributed implementation: {}.'.format(self._name, self._args.distributed_impl.value)
+                        ' distributed implementation: {}.'.format(self._name, self._args.distributed_impl)
                     )
                     return False
+                torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
                 self._world_size = int(os.environ['WORLD_SIZE'])
                 self._local_rank = int(os.environ['LOCAL_RANK'])
             else:
                 logger.error(
                     'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
-                        self._name, self._args.distributed_impl.value
+                        self._name, self._args.distributed_impl
                     )
                 )
                 return False
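Two fixes land in this hunk. First, os.environ.get() returns None for a missing key, never False, so the old 'is False' comparison could never detect an absent variable. Second, init_process_group() now runs only after the environment has been validated. A small standalone sketch of the first point (illustrative only):

import os

os.environ.pop('WORLD_SIZE', None)            # simulate a missing variable
print(os.environ.get('WORLD_SIZE') is False)  # False - the old check never fires
print(os.environ.get('WORLD_SIZE') is None)   # True  - missing key detected

Launchers such as torchrun (or the older torch.distributed.launch) export WORLD_SIZE and LOCAL_RANK; validating them up front avoids calling init_process_group() in an environment where it can hang or fail with a less descriptive error.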
@@ -89,14 +89,22 @@ def _init_dataloader(self):
                         rank=hvd.rank()
                     )
             elif self._args.distributed_impl == DistributedImpl.DDP:
-                train_sampler = \
-                    torch.utils.data.distributed.DistributedSampler(
-                        self._dataset
-                    )
+                try:
+                    train_sampler = \
+                        torch.utils.data.distributed.DistributedSampler(
+                            self._dataset
+                        )
+                except BaseException as e:
+                    logger.error(
+                        'Init dataloader failed - model: {}, distributed implementation: {}, message: {}.'.format(
+                            self._name, self._args.distributed_impl, str(e)
+                        )
+                    )
+                    return False
             else:
                 logger.error(
                     'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
-                        self._name, self._args.distributed_impl.value
+                        self._name, self._args.distributed_impl
                     )
                 )
                 return False
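The new try/except guards a concrete failure mode: constructed without explicit num_replicas/rank arguments, DistributedSampler queries the default process group and raises a RuntimeError if torch.distributed was never initialized. A self-contained sketch of that behavior (the dataset is a stand-in for self._dataset, and Exception is used here where the commit catches BaseException):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.zeros(8))  # stand-in for self._dataset
try:
    sampler = DistributedSampler(dataset)  # queries the default process group
except Exception as e:
    print('Init dataloader failed - message: {}.'.format(e))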
@@ -131,12 +139,12 @@ def _create_optimizer(self):
             self._optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
         elif self._optimizer_type == Optimizer.ADAMW:
             self._optimizer = torch.optim.AdamW(self._model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
+        else:
+            self._optimizer = None
         if not self._optimizer:
             logger.error(
-                'Create optimizer failed - model: {}, optimizer type: {}.'.format(
-                    self._name, self._optimizer_type.value
-                )
+                'Create optimizer failed - model: {}, optimizer type: {}.'.format(self._name, self._optimizer_type)
             )
             return False
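The added else branch matters because without it an unrecognized optimizer type leaves self._optimizer at whatever it held before, so the 'if not self._optimizer' guard might never fire. A simplified, self-contained sketch of the pattern (the function and type strings are illustrative, not the module's API):

import torch

model = torch.nn.Linear(4, 2)

def create_optimizer(optimizer_type):
    optimizer = None  # mirrors the commit's else-branch reset
    if optimizer_type == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
    elif optimizer_type == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
    if not optimizer:
        print('Create optimizer failed - optimizer type: {}.'.format(optimizer_type))
        return None
    return optimizer

create_optimizer('adamw')    # succeeds
create_optimizer('rmsprop')  # reports failure and returns None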