Commit bf2fd0c0 authored by Rick Ho

support private APIs of multiple PyTorch versions

parent 481f5c4f
 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from .utils import get_torch_default_comm

 class DistributedGroupedDataParallel(nn.Module):
@@ -17,9 +18,9 @@ class DistributedGroupedDataParallel(nn.Module):
         if dp_group is not None:
             self.comms['dp'] = dp_group
         else:
-            self.comms['dp'] = torch.distributed.distributed_c10d._get_default_group()
+            self.comms['dp'] = get_torch_default_comm()
         if world_group is None:
-            self.comms['world'] = torch.distributed.distributed_c10d._get_default_group()
+            self.comms['world'] = get_torch_default_comm()
         else:
             self.comms['world'] = world_group
...
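For context, a minimal usage sketch, not part of this commit: it assumes the class is importable as fmoe.distributed.DistributedGroupedDataParallel and that its constructor takes the wrapped module plus optional dp_group / world_group process groups, as the hunk above suggests. When world_group is omitted, the wrapper now resolves it through get_torch_default_comm() instead of calling the private _get_default_group() directly.

import torch.distributed as dist
import torch.nn as nn

from fmoe.distributed import DistributedGroupedDataParallel  # assumed import path

# Assumes a torchrun/env:// style launch so rank and world size are available.
dist.init_process_group(backend='nccl')
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

model = nn.Linear(16, 16).cuda()
# world_group is left unset, so the wrapper falls back to get_torch_default_comm().
model = DistributedGroupedDataParallel(model, dp_group=dp_group)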
@@ -7,6 +7,7 @@ computation.
 import torch
 from torch.autograd import Function
 import fmoe_cuda
+from .utils import get_torch_default_comm

 def moe_prepare_forward(gate, num_expert, world_size, comm=None):
@@ -21,7 +22,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         comm: the communicator of all workers in the expert-parallel group.
     """
     if comm is None:
-        comm = torch.distributed.distributed_c10d._get_default_group()
+        comm = get_torch_default_comm()
     if world_size > 1:
         fmoe_cuda.ensure_nccl(comm, gate)
...
import torch.distributed as dist


def get_torch_default_comm():
    r"""
    Return PyTorch's default communicator (process group), trying the
    private APIs of different PyTorch versions one by one.
    """
    try:
        # Newer PyTorch versions expose the default group via this helper.
        comm = dist.distributed_c10d._get_default_group()
        return comm
    except Exception as _:
        pass
    try:
        # Older versions keep the default group in a private module attribute.
        comm = dist.distributed_c10d._default_pg
        if comm is not None:
            return comm
    except Exception as _:
        pass
    raise RuntimeError('Unsupported PyTorch version')
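As a quick sanity check, not part of the commit: assuming the new file is the package's utils module referenced by the `from .utils import get_torch_default_comm` lines above (e.g. importable as fmoe.utils), the helper can be exercised with a single-process gloo group. The port number is arbitrary.

import torch.distributed as dist

from fmoe.utils import get_torch_default_comm  # assumed module path for the new file

# A single-process gloo group is enough to make the default group available.
dist.init_process_group(
    backend='gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1
)

comm = get_torch_default_comm()
print(type(comm))  # the concrete ProcessGroup type differs across PyTorch versions
dist.destroy_process_group()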