"tools/vscode:/vscode.git/clone" did not exist on "574446aec2ef3aec28cac7fef42b3365f1bee906"
Commit 2d250fbf authored by Rick Ho

make the test run on the NCCL version, but it still fails in correctness

parent 293eef6d
@@ -13,50 +13,60 @@
 #include "cuda_stream_manager.h"
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 void moe_cuda_expert_exchange_impl(
-        const int* local_expert_count,
-        int* global_expert_count,
-        int* fwd_expert_count,
-        int num_expert, int world_size) {
-    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
-            global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
-    for (int i = 0; i < num_expert; ++i) {
-        for (int j = 0; j < world_size; ++j) {
-            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
-        }
-    }
+        const long* local_expert_count,
+        long* global_expert_count,
+        int num_expert, int world_size,
+        CudaStreamManager* smgr) {
+    NCCL_SAFE_CALL(ncclGroupStart());
+    for (int i = 0; i < world_size; ++i) {
+        NCCL_SAFE_CALL(ncclSend(
+                local_expert_count + num_expert * i,
+                num_expert,
+                ncclInt64,
+                i,
+                smgr->ncclcomm,
+                smgr->stream(0)));
+        NCCL_SAFE_CALL(ncclRecv(
+                global_expert_count + num_expert * i,
+                num_expert,
+                ncclInt64,
+                i,
+                smgr->ncclcomm,
+                smgr->stream(0)));
+    }
+    NCCL_SAFE_CALL(ncclGroupEnd());
+    smgr->sync(1);
 }
 
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
         torch::Tensor local_expert_count,
         long num_expert, long n_workers) {
     auto global_expert_count = torch::empty_like(local_expert_count);
-    auto fwe_options = torch::TensorOptions()
-        .dtype(local_expert_count.dtype());
-    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
+    auto smgr = getCudaStreamManager(local_expert_count.device().index());
     moe_cuda_expert_exchange_impl(
-        local_expert_count.data_ptr<int>(),
-        global_expert_count.data_ptr<int>(),
-        fwd_expert_count.data_ptr<int>(),
-        num_expert, n_workers);
-    return {global_expert_count, fwd_expert_count};
+        local_expert_count.data_ptr<long>(),
+        global_expert_count.data_ptr<long>(),
+        num_expert, n_workers,
+        smgr);
+    return {global_expert_count};
 }
 
 template<typename scalar_t>
 void moe_cuda_global_scatter_impl(
         const scalar_t* local_input_buf,
-        const int* local_expert_count,
-        const int* global_expert_count,
+        const long* local_expert_count,
+        const long* global_expert_count,
         scalar_t* input_buf,
         size_t in_feat, size_t num_expert, size_t world_size,
         CudaStreamManager* smgr) {
     // assert world_size > 1
     int recv_ptr = 0;
     /* TODO: may save for backward */
-    int *expert_ptr = new int[num_expert * world_size];
+    long *expert_ptr = new long[num_expert * world_size];
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert * world_size; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
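For reference, the ncclSend/ncclRecv group above amounts to an all-to-all exchange of the per-expert token counts: the chunk of num_expert entries at offset num_expert * i in local_expert_count goes to rank i, and rank i's corresponding chunk arrives at the same offset in global_expert_count. A minimal PyTorch-level sketch of the same semantics (not the extension's code; it assumes torch.distributed has already been initialized with the NCCL backend):

```python
import torch
import torch.distributed as dist

def exchange_expert_counts(local_expert_count: torch.Tensor) -> torch.Tensor:
    # local_expert_count has world_size * num_expert entries; the slice
    # [r * num_expert : (r + 1) * num_expert] holds the local token counts
    # destined for rank r's experts.
    global_expert_count = torch.empty_like(local_expert_count)
    # Chunk r goes to rank r and rank r's chunk comes back, which is what the
    # paired ncclSend/ncclRecv calls inside the NCCL group implement.
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count
```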
@@ -106,8 +116,8 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
             "moe_cuda_global_scatter", ([&] {
         moe_cuda_global_scatter_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             global_input_buf.data_ptr<scalar_t>(),
             in_feat, num_expert, n_workers,
             smgr
@@ -119,14 +129,14 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
 template<typename scalar_t>
 void moe_cuda_global_gather_impl(
         const scalar_t* output_buf,
-        const int* local_expert_count,
-        const int* global_expert_count,
+        const long* local_expert_count,
+        const long* global_expert_count,
         scalar_t* local_output_buf,
         size_t out_feat, size_t num_expert, size_t world_size,
         CudaStreamManager* smgr) {
-    int send_ptr = 0;
+    long send_ptr = 0;
     /* TODO: may save for backward */
-    int *expert_ptr = new int[num_expert * world_size];
+    long *expert_ptr = new long[num_expert * world_size];
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert * world_size; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
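Both the scatter and gather paths build expert_ptr as an exclusive prefix sum over the flattened counts, so each entry is the starting offset of one (rank, expert) chunk in the local buffer. A small sketch with made-up counts:

```python
import torch

# Toy values; the real array has num_expert * world_size entries.
local_expert_count = torch.tensor([2, 1, 0, 2])
expert_ptr = torch.zeros_like(local_expert_count)
expert_ptr[1:] = torch.cumsum(local_expert_count, dim=0)[:-1]
# expert_ptr -> tensor([0, 2, 3, 3]); chunk i starts at offset expert_ptr[i]
```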
@@ -176,8 +186,8 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
             "moe_cuda_global_gather", ([&] {
         moe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             local_output_buf.data_ptr<scalar_t>(),
             out_feat, num_expert, n_workers,
             smgr
@@ -186,4 +196,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     return {local_output_buf,};
 }
 
+void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
+    auto smgr = getCudaStreamManager(0);
+    smgr->ensure((void*)&p, t.device());
+}
+
 #endif
@@ -41,6 +41,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
 
+#include <c10d/ProcessGroupNCCL.hpp>
+void moe_ensure_nccl(c10d::ProcessGroupNCCL&, torch::Tensor t);
+
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
         torch::Tensor local_expert_count,
         long num_expert, long n_workers);
......
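The new ensure_nccl entry point is meant to be called once per process before any MoE communication: it hands PyTorch's default process group to the extension so the stream manager can reuse its NCCL communicator, and the tensor argument only supplies the target device. A hedged usage sketch, mirroring the call added to MOEGlobal.forward further down:

```python
import torch
import torch.distributed as dist
import fmoe_cuda  # the compiled extension

dist.init_process_group(backend='nccl')
# Any CUDA tensor on the right device works; ensure() only reads its device.
probe = torch.empty(1, device='cuda')
fmoe_cuda.ensure_nccl(dist.distributed_c10d._default_pg, probe)
```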
@@ -14,7 +14,6 @@
 #include "cublas_wrapper.h"
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 template<typename scalar_t>
@@ -24,8 +23,8 @@ void moe_cuda_global_fused_forward_impl(
         scalar_t* global_input_buf,
         scalar_t* global_output_buf,
         scalar_t* output_buf,
-        const int* local_expert_count,
-        const int* global_expert_count,
+        const long* local_expert_count,
+        const long* global_expert_count,
         long in_feat, long out_feat,
         long num_expert, long world_size,
         CudaStreamManager* smgr) {
@@ -136,8 +135,8 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
             global_input_buf.data_ptr<scalar_t>(),
             global_output_buf.data_ptr<scalar_t>(),
             output_buf.data_ptr<scalar_t>(),
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             in_feat, out_feat, num_expert, n_workers,
             smgr);
         }));
......
@@ -38,16 +38,30 @@ class MOELocal(Function):
 
 
 class MOEGlobal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight, world_size):
+        fmoe_cuda.ensure_nccl(
+                torch.distributed.distributed_c10d._default_pg, inp)
         num_expert = weight.shape[0]
-        local_expert_count, pos = fmoe_cuda.expert_count(gate,
-                world_size * num_expert)
-        global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
+        # local_expert_count, pos = fmoe_cuda.expert_count(gate,
+        #         world_size * num_expert)
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        local_expert_count = torch.zeros(weight.shape[0] * world_size,
+                device=weight.device, dtype=torch.long)
+        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
+        global_expert_count, = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size)
+        print('Local {} Global {}'.format(local_expert_count, global_expert_count))
+        fwd_expert_count = global_expert_count.view(num_expert,
+                world_size).sum(dim=1).cpu()
         fwd_batch_size = int(fwd_expert_count.sum().item())
         local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
+        local_expert_count = local_expert_count.cpu()
+        global_expert_count = global_expert_count.cpu()
         local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                 local_input_buf, weight,
                 local_expert_count, global_expert_count,
......
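The rewritten forward drops fmoe_cuda.expert_count and derives the counts directly from the gate tensor with torch.sort, torch.unique, and index_put_. A toy example of that counting step (illustrative values, CPU tensors for brevity):

```python
import torch

num_expert, world_size = 2, 2
gate = torch.tensor([0, 3, 3, 1, 0])   # expert index chosen for each token
_, pos = torch.sort(gate)              # token positions grouped by expert
gate_idx, gate_count = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(num_expert * world_size, dtype=torch.long)
local_expert_count.index_put_((gate_idx.long(),), gate_count)
# local_expert_count -> tensor([2, 1, 0, 2])
```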
@@ -8,7 +8,6 @@ cxx_flags = [
 ]
 if os.environ.get('USE_NCCL', '0') == '1':
     cxx_flags.append('-DMOE_USE_NCCL')
-    os.environ['CXX'] = 'mpicxx'
 
 if __name__ == '__main__':
     setuptools.setup(
......
@@ -2,6 +2,8 @@
 if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
 then
     export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
+    export MASTER_ADDR=localhost
+    export MASTER_PORT=36666
 fi
 
 if [ -z $OMPI_COMM_WORLD_RANK ]
......
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 import time
 import sys
+import os
 
 
 dev_name_default = 'cuda:0'
@@ -105,10 +106,10 @@ def test():
     if world_size == 1:
         moe_raw.weight.data = moe.weight.data.clone()
     else:
-        weight_array = [torch.empty_like(moe.weight.data).cpu()
+        weight_array = [torch.empty_like(moe.weight.data)
                 for _ in range(world_size)]
-        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
-        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()
+        torch.distributed.all_gather(weight_array, moe.weight.data)
+        moe_raw.weight.data = torch.cat(weight_array, dim=0)
 
     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0,
@@ -124,13 +125,20 @@ def test():
     if world_size > 1:
         rank = torch.distributed.get_rank()
         ou, wg, lwg, lbg = raw_out
-        wg = wg.cpu()
         torch.distributed.all_reduce(wg)
         wg = wg[rank * num_expert:(rank + 1)* num_expert]
-        raw_out = ou, wg.cuda(), lwg, lbg
+        raw_out = ou, wg, lwg, lbg
+    else:
+        rank = 0
 
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
-        print('{} abs err {}'.format(name, err))
+        print('Rank {} {} abs err {}'.format(rank, name, err))
+        if err > 1e-3:
+            sys.stderr.write('=========== moe out ==============\n')
+            sys.stderr.write('{}'.format(mo))
+            sys.stderr.write('=========== raw out ==============\n')
+            sys.stderr.write('{}'.format(ro))
+            return
 
 def test_dp():
@@ -158,7 +166,9 @@ def test_dp():
 
 
 if __name__ == '__main__':
-    torch.distributed.init_process_group(backend='mpi')
+    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', 0)
+    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', 1)
+    torch.distributed.init_process_group(backend='nccl')
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     if len(sys.argv) >= 2:
......
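Together with the MASTER_ADDR/MASTER_PORT exports added to the run script, the test now bootstraps torch.distributed from the Open MPI environment instead of using the MPI backend. A condensed sketch of that initialization path, using only the variable names and port that appear in the diff:

```python
import os
import torch.distributed as dist

# Map Open MPI's rank/size variables to the names expected by the default
# env:// rendezvous of init_process_group.
os.environ.setdefault('RANK', os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
os.environ.setdefault('WORLD_SIZE', os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '36666')

dist.init_process_group(backend='nccl')
```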