import torch
import os


def init_ddp(visiable_devices='0,1,2,3'):
    """Initialize single-node DistributedDataParallel (DDP) training.

    Intended to be run under ``torchrun`` / ``torch.distributed.launch``,
    which exports a per-process ``LOCAL_RANK`` environment variable.

    Args:
        visiable_devices: Comma-separated GPU indices written to
            ``HIP_VISIBLE_DEVICES``. (Name keeps the original "visiable"
            spelling for backward compatibility with keyword callers.)

    Returns:
        The integer local rank of this process when more than one GPU is
        visible, otherwise ``None`` (no distributed init is performed).

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment while
            multiple GPUs are available (i.e. not launched via torchrun).
    """
    if torch.cuda.device_count() > 1:
        # NOTE(review): this is probably too late to take effect —
        # torch.cuda.device_count() above already initialized the GPU
        # runtime, and HIP/CUDA reads *_VISIBLE_DEVICES only at init.
        # Set it in the launcher before torch touches the GPU — TODO confirm.
        os.environ['HIP_VISIBLE_DEVICES'] = visiable_devices

        # torchrun exports one LOCAL_RANK per spawned worker process.
        local_rank = int(os.environ["LOCAL_RANK"])
        print("local_rank:" + str(local_rank))

        # Reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the
        # environment (the default "env://" init method).
        torch.distributed.init_process_group(backend="nccl")
        # Bind this process to its own GPU before any model/tensor moves.
        torch.cuda.set_device(local_rank)
        return local_rank
    else:
        return None