Commit bf2fd0c0 authored by Rick Ho

support multiple PyTorch versions' private APIs

parent 481f5c4f
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from .utils import get_torch_default_comm


class DistributedGroupedDataParallel(nn.Module):
@@ -17,9 +18,9 @@ class DistributedGroupedDataParallel(nn.Module):
        if dp_group is not None:
            self.comms['dp'] = dp_group
        else:
-            self.comms['dp'] = torch.distributed.distributed_c10d._get_default_group()
+            self.comms['dp'] = get_torch_default_comm()
        if world_group is None:
-            self.comms['world'] = torch.distributed.distributed_c10d._get_default_group()
+            self.comms['world'] = get_torch_default_comm()
        else:
            self.comms['world'] = world_group
@@ -7,6 +7,7 @@ computation.
import torch
from torch.autograd import Function
import fmoe_cuda
+from .utils import get_torch_default_comm


def moe_prepare_forward(gate, num_expert, world_size, comm=None):
@@ -21,7 +22,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
    comm: the communicator of all workers in the expert-parallel group.
    """
    if comm is None:
-        comm = torch.distributed.distributed_c10d._get_default_group()
+        comm = get_torch_default_comm()
    if world_size > 1:
        fmoe_cuda.ensure_nccl(comm, gate)
import torch.distributed as dist


def get_torch_default_comm():
    # The default process group is exposed through different private APIs
    # depending on the PyTorch version, so probe them one by one and fall
    # through silently when an attribute is missing.
    try:
        comm = dist.distributed_c10d._get_default_group()
        return comm
    except Exception as _:
        pass
    try:
        comm = dist.distributed_c10d._default_pg
        if comm is not None:
            return comm
    except Exception as _:
        pass
    raise RuntimeError('Unsupported PyTorch version')
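
For context, a minimal usage sketch (not part of the commit) of how a caller can obtain the default communicator through the new helper rather than a version-specific private attribute. The import path fmoe.utils and the single-process gloo group are assumptions made purely for illustration.

import os
import torch.distributed as dist
# Assumed import path for the helper added in this commit.
from fmoe.utils import get_torch_default_comm

if __name__ == '__main__':
    # Single-process group with the gloo backend, for illustration only.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    # Resolves the default process group on both older and newer PyTorch
    # versions without touching torch.distributed.distributed_c10d directly.
    comm = get_torch_default_comm()
    print(type(comm))

    dist.destroy_process_group()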