Commit cd371d31 authored by Shangyan Zhou's avatar Shangyan Zhou
Browse files

Move import.

parent bf4a4a21
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from typing import Optional from typing import Optional
import inspect
def init_dist(local_rank: int, num_local_ranks: int): def init_dist(local_rank: int, num_local_ranks: int):
...@@ -14,7 +15,6 @@ def init_dist(local_rank: int, num_local_ranks: int): ...@@ -14,7 +15,6 @@ def init_dist(local_rank: int, num_local_ranks: int):
node_rank = int(os.getenv('RANK', 0)) node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
import inspect
sig = inspect.signature(dist.init_process_group) sig = inspect.signature(dist.init_process_group)
params = { params = {
'backend': 'nccl', 'backend': 'nccl',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment