"...git@developer.sourcefind.cn:2222/OpenDAS/flashmla.git" did not exist on "76bb5d10cee439a8c6ca3ae5f53463c955cd8822"
Commit bf4a4a21 authored by Shangyan Zhou's avatar Shangyan Zhou
Browse files

Set `device_id` to suppress pytorch warning.

parent 77f97f79
...@@ -14,12 +14,17 @@ def init_dist(local_rank: int, num_local_ranks: int): ...@@ -14,12 +14,17 @@ def init_dist(local_rank: int, num_local_ranks: int):
node_rank = int(os.getenv('RANK', 0)) node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group( import inspect
backend='nccl', sig = inspect.signature(dist.init_process_group)
init_method=f'tcp://{ip}:{port}', params = {
world_size=num_nodes * num_local_ranks, 'backend': 'nccl',
rank=node_rank * num_local_ranks + local_rank 'init_method': f'tcp://{ip}:{port}',
) 'world_size': num_nodes * num_local_ranks,
'rank': node_rank * num_local_ranks + local_rank,
}
if 'device_id' in sig.parameters:
params['device_id'] = torch.device(f"cuda:{local_rank}")
dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda') torch.set_default_device('cuda')
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment