test.py
import os
import torch

def main():
    # A distributed launcher such as torchrun sets RANK and WORLD_SIZE for
    # each process; fall back to single-process defaults when run directly.
    # (Deriving the rank from torch.cuda.current_device() would report 0 in
    # every process, since no device has been selected yet.)
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')

    # Bind this process to one GPU; with one process per GPU, local
    # processes land on distinct devices.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Build the TCP rendezvous address from the launcher-provided environment.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = f'tcp://{master_ip}:{master_port}'
    torch.distributed.init_process_group(
        backend='nccl', world_size=world_size, rank=rank, init_method=init_method)

    # Run one collective as a sanity check, then shut down cleanly.
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()

if __name__ == '__main__':
    main()
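
This script is only a minimal connectivity check for torch.distributed. One way to exercise it on a single node, assuming torchrun is available, is: torchrun --nproc_per_node=<num_gpus> test.py, where <num_gpus> is a placeholder for the number of local GPUs. torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for each spawned process, which is exactly what main() reads from the environment.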