Unverified Commit e31b8c9e authored by user4543's avatar user4543 Committed by GitHub
Browse files

Benchmarks: Revise Code - Add support for pytorch>=1.9.0 of init_process_group (#305)

**Description**
Add support for pytorch>=1.9.0 of init_process_group.

**Major Revision**
- Use PrefixStore(TCPStore) to init_process_group manully for each model run
parent 4abda6f5
......@@ -4,10 +4,12 @@
"""Module of the Pytorch model-benchmark base class."""
import os
from datetime import timedelta
import torch
import transformers
from torch.utils.data import DataLoader
from torch.distributed import TCPStore, PrefixStore
from superbench.common.utils import logger
from superbench.benchmarks import Framework, ReturnCode, DistributedBackend, DistributedImpl
......@@ -65,10 +67,24 @@ def _init_distributed_setting(self):
' distributed implementation: {}.'.format(self._name, self._args.distributed_impl)
)
return False
torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
self._world_size = int(os.environ['WORLD_SIZE'])
# torch >= 1.9.0a0 torch.distributed.elastic is used by default
port = int(os.environ['MASTER_PORT']) + 1
addr = os.environ['MASTER_ADDR']
global_rank = int(os.environ['RANK'])
self._local_rank = int(os.environ['LOCAL_RANK'])
self._world_size = int(os.environ['WORLD_SIZE'])
logger.debug('ip:{},port:{},rank:{},world:{}'.format(addr, port, global_rank, self._world_size))
store = PrefixStore(
self._name, TCPStore(addr, port, self._world_size, global_rank == 0, timedelta(seconds=300))
)
torch.distributed.init_process_group(
backend=self._args.distributed_backend.value,
timeout=timedelta(seconds=300),
rank=global_rank,
world_size=self._world_size,
store=store
)
else:
logger.error(
'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment