Unverified Commit 1f726091 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Add _postprocess() implementation in pytorch_base.py to clean up distributed resources. (#45)


Co-authored-by: default avatarGuoshuai Zhao <guzhao@microsoft.com>
parent 0172968f
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.benchmarks import Framework from superbench.benchmarks import Framework, ReturnCode
from superbench.benchmarks.model_benchmarks.model_base import Optimizer, DistributedImpl, ModelBenchmark from superbench.benchmarks.model_benchmarks.model_base import Optimizer, DistributedImpl, ModelBenchmark
...@@ -161,6 +161,29 @@ def _create_optimizer(self): ...@@ -161,6 +161,29 @@ def _create_optimizer(self):
return True return True
def _postprocess(self):
    """Run post-benchmark cleanup and release the distributed process group.

    The parent class cleanup runs first. When the configured distributed
    implementation is DDP, the torch.distributed process group is destroyed.

    Return:
        True if the parent cleanup and the distributed teardown both succeed,
        False otherwise.
    """
    parent_ok = super()._postprocess()
    if not parent_ok:
        return False

    try:
        uses_ddp = self._args.distributed_impl == DistributedImpl.DDP
        if uses_ddp:
            torch.distributed.destroy_process_group()
    except BaseException as err:
        # Record the failure on the result first, then log the details.
        self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
        template = 'Post process failed - model: {}, distributed implementation: {}, message: {}.'
        logger.error(template.format(self._name, self._args.distributed_impl, str(err)))
        return False
    else:
        return True
def _cal_params_count(self): def _cal_params_count(self):
"""Calculate the parameters scale of the model. """Calculate the parameters scale of the model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment