Commit 2d250fbf authored by Rick Ho

Make the test run on the NCCL version, but it still fails the correctness check

parent 293eef6d
@@ -13,50 +13,60 @@
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
void moe_cuda_expert_exchange_impl(
const int* local_expert_count,
int* global_expert_count,
int* fwd_expert_count,
int num_expert, int world_size) {
MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < world_size; ++j) {
fwd_expert_count[i] += global_expert_count[i + j * num_expert];
}
const long* local_expert_count,
long* global_expert_count,
int num_expert, int world_size,
CudaStreamManager* smgr) {
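// All-to-all exchange of per-expert token counts: send this worker's counts to
// every peer and receive theirs, batched in a single NCCL group call
// (this replaces the earlier MPI_Alltoall).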
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto fwe_options = torch::TensorOptions()
.dtype(local_expert_count.dtype());
auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
auto smgr = getCudaStreamManager(local_expert_count.device().index());
moe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
fwd_expert_count.data_ptr<int>(),
num_expert, n_workers);
return {global_expert_count, fwd_expert_count};
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
num_expert, n_workers,
smgr);
return {global_expert_count};
}
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const int* local_expert_count,
const int* global_expert_count,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
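// expert_ptr holds exclusive prefix sums of local_expert_count, i.e. the offset
// of each (worker, expert) block inside the local input buffer.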
int *expert_ptr = new int[num_expert * world_size];
long *expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
@@ -106,8 +116,8 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
"moe_cuda_global_scatter", ([&] {
moe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, num_expert, n_workers,
smgr
@@ -119,14 +129,14 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
template<typename scalar_t>
void moe_cuda_global_gather_impl(
const scalar_t* output_buf,
const int* local_expert_count,
const int* global_expert_count,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
int send_ptr = 0;
long send_ptr = 0;
/* TODO: may save for backward */
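// Same exclusive prefix sum as in the scatter: expert_ptr locates each
// (worker, expert) block inside the local output buffer.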
int *expert_ptr = new int[num_expert * world_size];
long *expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
@@ -176,8 +186,8 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
"moe_cuda_global_gather", ([&] {
moe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, num_expert, n_workers,
smgr
@@ -186,4 +196,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
return {local_output_buf,};
}
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
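// Make sure the stream manager uses the NCCL communicator owned by the given
// ProcessGroupNCCL (presumably what smgr->ensure() extracts from p for this device).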
auto smgr = getCudaStreamManager(0);
smgr->ensure((void*)&p, t.device());
}
#endif
@@ -41,6 +41,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor global_expert_count,
long batch_size, long n_workers);
#include <c10d/ProcessGroupNCCL.hpp>
void moe_ensure_nccl(c10d::ProcessGroupNCCL&, torch::Tensor t);
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers);
@@ -14,7 +14,6 @@
#include "cublas_wrapper.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
template<typename scalar_t>
@@ -24,8 +23,8 @@ void moe_cuda_global_fused_forward_impl(
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const int* local_expert_count,
const int* global_expert_count,
const long* local_expert_count,
const long* global_expert_count,
long in_feat, long out_feat,
long num_expert, long world_size,
CudaStreamManager* smgr) {
@@ -136,8 +135,8 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<int>(),
global_expert_count.data_ptr<int>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
in_feat, out_feat, num_expert, n_workers,
smgr);
}));
@@ -38,16 +38,30 @@ class MOELocal(Function):
class MOEGlobal(Function):
@staticmethod
def forward(ctx, inp, gate, weight, world_size):
fmoe_cuda.ensure_nccl(
torch.distributed.distributed_c10d._default_pg, inp)
num_expert = weight.shape[0]
local_expert_count, pos = fmoe_cuda.expert_count(gate,
world_size * num_expert)
global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
# local_expert_count, pos = fmoe_cuda.expert_count(gate,
# world_size * num_expert)
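# Count how many local tokens go to each of the world_size * num_expert experts
# with plain torch ops: sort gives the token positions, unique gives the counts.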
_, pos = torch.sort(gate)
gate_idx, gate_count = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(weight.shape[0] * world_size,
device=weight.device, dtype=torch.long)
local_expert_count.index_put_((gate_idx.long(), ), gate_count)
global_expert_count, = fmoe_cuda.expert_exchange(
local_expert_count, num_expert, world_size)
print('Local {} Global {}'.format(local_expert_count, global_expert_count))
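# Total number of tokens each local expert will process after the exchange.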
fwd_expert_count = global_expert_count.view(num_expert,
world_size).sum(dim=1).cpu()
fwd_batch_size = int(fwd_expert_count.sum().item())
local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
local_expert_count = local_expert_count.cpu()
global_expert_count = global_expert_count.cpu()
local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
local_input_buf, weight,
local_expert_count, global_expert_count,
@@ -8,7 +8,6 @@ cxx_flags = [
]
if os.environ.get('USE_NCCL', '0') == '1':
cxx_flags.append('-DMOE_USE_NCCL')
os.environ['CXX'] = 'mpicxx'
if __name__ == '__main__':
setuptools.setup(
@@ -2,6 +2,8 @@
if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
then
export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
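# torch.distributed needs a rendezvous address and port when it is initialised
# from environment variables.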
export MASTER_ADDR=localhost
export MASTER_PORT=36666
fi
if [ -z $OMPI_COMM_WORLD_RANK ]
@@ -4,6 +4,7 @@ import torch
from torch import nn
import time
import sys
import os
dev_name_default = 'cuda:0'
@@ -105,10 +106,10 @@ def test():
if world_size == 1:
moe_raw.weight.data = moe.weight.data.clone()
else:
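# The nccl backend all-gathers CUDA tensors directly, so the weights can stay on the GPU.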
weight_array = [torch.empty_like(moe.weight.data).cpu()
weight_array = [torch.empty_like(moe.weight.data)
for _ in range(world_size)]
torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()
torch.distributed.all_gather(weight_array, moe.weight.data)
moe_raw.weight.data = torch.cat(weight_array, dim=0)
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0,
@@ -124,13 +125,20 @@ def test():
if world_size > 1:
rank = torch.distributed.get_rank()
ou, wg, lwg, lbg = raw_out
wg = wg.cpu()
torch.distributed.all_reduce(wg)
wg = wg[rank * num_expert:(rank + 1)* num_expert]
raw_out = ou, wg.cuda(), lwg, lbg
raw_out = ou, wg, lwg, lbg
else:
rank = 0
for name, mo, ro in zip(names, moe_out, raw_out):
err = (mo - ro).abs().sum()
print('{} abs err {}'.format(name, err))
print('Rank {} {} abs err {}'.format(rank, name, err))
if err > 1e-3:
sys.stderr.write('=========== moe out ==============\n')
sys.stderr.write('{}\n'.format(mo))
sys.stderr.write('=========== raw out ==============\n')
sys.stderr.write('{}\n'.format(ro))
return
def test_dp():
@@ -158,7 +166,9 @@ def test_dp():
if __name__ == '__main__':
torch.distributed.init_process_group(backend='mpi')
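# Map Open MPI's rank/size environment variables to the ones expected by
# torch.distributed before initialising the nccl backend.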
os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
torch.distributed.init_process_group(backend='nccl')
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if len(sys.argv) >= 2: