"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "2d9be807a9690fff140f0e8ba9cbd297edd6d502"
Unverified commit e31b8c9e (parent 4abda6f5), authored by user4543 and committed by GitHub

Benchmarks: Revise Code - Add support for pytorch>=1.9.0 of init_process_group (#305)

**Description**
Add support for init_process_group on pytorch>=1.9.0.

**Major Revision**
- Use PrefixStore(TCPStore) to call init_process_group manually for each model run (see the sketch below)
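
For context, here is a minimal, self-contained sketch of the pattern this commit applies. It stands outside SuperBench: the `'nccl'` backend and the `'resnet101'` prefix are illustrative stand-ins for `self._args.distributed_backend.value` and `self._name`, and the environment variables are the standard ones set by torchrun / torch.distributed.elastic.

```python
import os
from datetime import timedelta

import torch
from torch.distributed import TCPStore, PrefixStore

# Standard rendezvous variables set by torchrun / torch.distributed.launch.
addr = os.environ['MASTER_ADDR']
# Offset by one so this store does not collide with the rendezvous store that
# torch.distributed.elastic already binds on MASTER_PORT (torch >= 1.9.0).
port = int(os.environ['MASTER_PORT']) + 1
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])

# Rank 0 hosts the TCPStore; every other rank connects to it as a client.
store = TCPStore(addr, port, world_size, rank == 0, timedelta(seconds=300))

# Namespacing the keys per benchmark ('resnet101' is a hypothetical name)
# lets one store back a separate process group for each model run.
torch.distributed.init_process_group(
    backend='nccl',    # assumption; SuperBench passes its configured backend
    store=PrefixStore('resnet101', store),
    rank=rank,
    world_size=world_size,
    timeout=timedelta(seconds=300),
)
```

When an explicit `store` is passed, `init_process_group` skips the `env://` rendezvous entirely, which is what sidesteps the conflict with elastic's default store.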
@@ -4,10 +4,12 @@
 """Module of the Pytorch model-benchmark base class."""
 
 import os
+from datetime import timedelta
 
 import torch
 import transformers
 from torch.utils.data import DataLoader
+from torch.distributed import TCPStore, PrefixStore
 
 from superbench.common.utils import logger
 from superbench.benchmarks import Framework, ReturnCode, DistributedBackend, DistributedImpl
@@ -65,10 +67,24 @@ def _init_distributed_setting(self):
                     ' distributed implementation: {}.'.format(self._name, self._args.distributed_impl)
                 )
                 return False
-            torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
-            self._world_size = int(os.environ['WORLD_SIZE'])
+            # torch >= 1.9.0a0 torch.distributed.elastic is used by default
+            port = int(os.environ['MASTER_PORT']) + 1
+            addr = os.environ['MASTER_ADDR']
+            global_rank = int(os.environ['RANK'])
             self._local_rank = int(os.environ['LOCAL_RANK'])
+            self._world_size = int(os.environ['WORLD_SIZE'])
+            logger.debug('ip:{},port:{},rank:{},world:{}'.format(addr, port, global_rank, self._world_size))
+            store = PrefixStore(
+                self._name, TCPStore(addr, port, self._world_size, global_rank == 0, timedelta(seconds=300))
+            )
+            torch.distributed.init_process_group(
+                backend=self._args.distributed_backend.value,
+                timeout=timedelta(seconds=300),
+                rank=global_rank,
+                world_size=self._world_size,
+                store=store
+            )
         else:
             logger.error(
                 'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
...
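A note on the design, reading beyond the inline comment: on torch >= 1.9.0 the launcher goes through torch.distributed.elastic, whose own rendezvous store occupies MASTER_PORT, so the benchmark binds its TCPStore on MASTER_PORT + 1 to avoid that collision, and the PrefixStore wrapper keyed by self._name keeps each model run's rendezvous keys in a separate namespace.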