Unverified Commit 7f6deabb authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

add _postprocess() implementation in sharding_matmul.py to clean up distributed resources. (#46)


Co-authored-by: Guoshuai Zhao <guzhao@microsoft.com>
parent 67053d9a
......@@ -14,7 +14,7 @@
import os
import time
# TODO - add mechanism to import torch as needed according to docker
# TODO - add mechanism to import torch as needed according to docker.
import torch
from superbench.common.utils import logger
......@@ -108,6 +108,7 @@ def _preprocess(self):
self.__local_rank = int(os.environ['LOCAL_RANK'])
except BaseException as e:
self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
torch.distributed.destroy_process_group()
logger.error(
'Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e))
)
......@@ -241,6 +242,29 @@ def _benchmark(self):
return True
def _postprocess(self):
    """Postprocess/cleanup operations after the benchmarking.

    Return:
        True if _postprocess() succeed.
    """
    # Let the base class run its own cleanup first; bail out if it failed.
    if not super()._postprocess():
        return False

    # The process group only exists for the distributed sharding modes,
    # so tear it down only when one of them was requested.
    distributed_modes = (ShardingMode.ALLGATHER, ShardingMode.ALLREDUCE)
    try:
        if any(mode in self._args.mode for mode in distributed_modes):
            torch.distributed.destroy_process_group()
    except BaseException as e:
        # Record the failure on the result object and surface it in the log.
        self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
        logger.error(
            'Post process failed - benchmark: {}, mode: {}, message: {}.'.format(
                self._name, self._args.mode, str(e)
            )
        )
        return False

    return True
# Register the sharded variant (runs both allreduce and allgather sharding modes)
# and the plain matmul variant (no sharding) under separate benchmark names.
BenchmarkRegistry.register_benchmark('pytorch-sharding-matmul', ShardingMatmul, parameters='--mode allreduce allgather')
BenchmarkRegistry.register_benchmark('pytorch-matmul', ShardingMatmul, parameters='--mode nosharding')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment