Commit 15f98a10 authored by Rick Ho

Adapt to PyTorch 1.8.0 (drop deprecated 1.6.0 support)

parent 585604fe
@@ -18,7 +18,7 @@ class HackNCCLGroup: public c10d::ProcessGroupNCCL {
 public:
     ncclComm_t getcomm(at::Device dev) {
         auto key = std::to_string(dev.index());
-        auto v = getNCCLComm(key, {dev});
+        auto v = getNCCLComm(key, {dev}, c10d::OpType::ALLTOALL);
         if (v.size() == 0) {
             std::cerr << "PyTorch has nothing\n";
             return 0;
...
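The third argument, c10d::OpType::ALLTOALL, exists only in the PyTorch 1.8 C++ API, so this code no longer builds against 1.6. If both versions had to be supported, the build script could detect the installed PyTorch and define a macro for the C++ sources to branch on. A minimal sketch, assuming a hypothetical FMOE_TORCH_GE_18 macro that is not part of this commit:

# Hypothetical build-time switch (e.g. in setup.py): expose the detected
# PyTorch version so the C++ sources could #ifdef between the 1.6 and 1.8
# getNCCLComm signatures.
import torch

major, minor = (int(v) for v in torch.__version__.split(".")[:2])
extra_compile_args = []
if (major, minor) >= (1, 8):
    # Macro name is illustrative only.
    extra_compile_args.append("-DFMOE_TORCH_GE_18")

The commit itself adds no such switch; it simply targets 1.8.0, as the message says.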
@@ -8,7 +8,6 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <helper_cuda.h>
-#include <c10/cuda/CUDAGuard.h>
 #include "cuda_stream_manager.h"
...
@@ -17,9 +17,9 @@ class DistributedGroupedDataParallel(nn.Module):
         if dp_group is not None:
             self.comms['dp'] = dp_group
         else:
-            self.comms['dp'] = torch.distributed.distributed_c10d._default_pg
+            self.comms['dp'] = torch.distributed.distributed_c10d._get_default_group()
         if world_group is None:
-            self.comms['world'] = torch.distributed.distributed_c10d._default_pg
+            self.comms['world'] = torch.distributed.distributed_c10d._get_default_group()
         else:
             self.comms['world'] = world_group
...
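The module-level _default_pg attribute is no longer available in PyTorch 1.8, and _get_default_group() is the replacement this commit switches to. A code base that still needed to run on 1.6 could hide the difference behind a small shim; a sketch, with the helper name _default_group being illustrative rather than part of this commit:

import torch.distributed.distributed_c10d as c10d

def _default_group():
    # PyTorch >= 1.8 exposes _get_default_group(); older releases kept the
    # default process group in the module-level _default_pg attribute.
    if hasattr(c10d, "_get_default_group"):
        return c10d._get_default_group()
    return c10d._default_pg

The same substitution is applied again in the next hunk.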
@@ -21,7 +21,7 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         comm: the communicator of all workers in the expert-parallel group.
     """
     if comm is None:
-        comm = torch.distributed.distributed_c10d._default_pg
+        comm = torch.distributed.distributed_c10d._get_default_group()
     if world_size > 1:
         fmoe_cuda.ensure_nccl(comm, gate)
...
@@ -4,7 +4,7 @@ from .distributed import DistributedGroupedDataParallel
 def create_moe_mlp(args, group):
     assert (
-        args.seq_length * args.batch_size % args.model_parallel_size == 0
+        args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
     if not args.distributed_experts:
         world_size = 1
...
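The last hunk follows Megatron-LM's renamed arguments (batch_size -> micro_batch_size, model_parallel_size -> tensor_model_parallel_size) rather than a PyTorch change. If the integration had to accept both the old and the new argument namespaces, a getattr fallback would do; a sketch, not part of the commit, with _arg as a hypothetical helper:

def _arg(args, new_name, old_name):
    # Prefer the new Megatron-LM argument name and fall back to the old one.
    if hasattr(args, new_name):
        return getattr(args, new_name)
    return getattr(args, old_name)

tokens = args.seq_length * _arg(args, "micro_batch_size", "batch_size")
mp_size = _arg(args, "tensor_model_parallel_size", "model_parallel_size")
assert tokens % mp_size == 0, \
    "Batch size x sequence length should be a multiple of mp size"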