DDP.py 653 Bytes
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import os

def init_ddp(visiable_devices='0,1,2,3'):
    """Set up the default distributed process group for this worker.

    Nothing is initialised unless ``torch.cuda.device_count()`` reports more
    than one device; in that case the function returns ``None`` immediately.

    Args:
        visiable_devices: Comma-separated device indices written to the
            ``HIP_VISIBLE_DEVICES`` environment variable. The parameter
            name's spelling is kept as-is for backward compatibility.

    Returns:
        The integer value of the ``LOCAL_RANK`` environment variable, after
        that device has been made the current one, or ``None`` when at most
        one device is reported.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not present in the environment.
        ValueError: If ``LOCAL_RANK`` cannot be parsed as an integer.
    """
    if torch.cuda.device_count() <= 1:
        return None

    # NOTE(review): this is assigned after the device count was queried
    # above; confirm the runtime still honours the variable at this point.
    os.environ['HIP_VISIBLE_DEVICES'] = visiable_devices

    try:
        local_rank = int(os.environ["LOCAL_RANK"])
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            "LOCAL_RANK is not set; start the script with a distributed "
            "launcher such as torchrun"
        ) from None
    print(f"local_rank:{local_rank}")

    # Initialising the default group twice raises, so skip it when an
    # earlier call (or the caller) has already done so.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")

    torch.cuda.set_device(local_rank)
    return local_rank