Commit bb3c8966 authored by Rick Ho

reconstruct fmoe cuda code

parent 4c90e6e8
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#include "cuda_stream_manager.h"
#define SMGR_N_STREAMS 16
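// Each device owns a fixed pool of SMGR_N_STREAMS CUDA streams and matching
// cuBLAS handles; stream(i)/handle(i) map an arbitrary index (typically an
// expert index) onto the pool with a modulo, so work launched for different
// experts can overlap on different streams.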
cudaStream_t CudaStreamManager::stream(size_t idx) {
return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
return this->handles[idx % SMGR_N_STREAMS];
}
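// sync(n) waits for the first min(n, SMGR_N_STREAMS) streams to drain; callers
// pass the number of streams they have just launched work on.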
void CudaStreamManager::sync(int idx) {
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
}
void CudaStreamManager::setup(const int device) {
#ifdef MOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
}
void CudaStreamManager::destroy() {
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
delete[] streams;
delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;
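// One lazily created CudaStreamManager per device. The lookup re-checks the
// map under the mutex before inserting (double-checked locking), so the fast
// path that finds an existing manager does not take the lock.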
CudaStreamManager* getCudaStreamManager(const int device) {
auto it = smgrs.find(device);
if (it == smgrs.end()) {
smgr_mtx.lock();
it = smgrs.find(device);
if (it == smgrs.end()) {
auto smgr = new CudaStreamManager(device);
smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
smgr_mtx.unlock();
return smgr;
} else {
smgr_mtx.unlock();
}
}
return it->second;
}
#include <iostream>
#include <vector>
#include <torch/extension.h>
#ifdef FMOE_USE_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
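// Forward declarations of the CUDA entry points that are bound to Python by
// the pybind11 module at the bottom of this file.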
std::vector<torch::Tensor> _expert_exchange(
torch::Tensor local_expert_count,
long n_expert, long n_workers);
std::vector<torch::Tensor> _global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> _global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // FMOE_USE_NCCL
std::vector<torch::Tensor> _expert_count(
torch::Tensor gate,
size_t num_expert);
std::vector<torch::Tensor> _local_scatter(
torch::Tensor input,
torch::Tensor pos);
std::vector<torch::Tensor> _local_gather(
torch::Tensor output_buf,
torch::Tensor pos);
std::vector<torch::Tensor> _linear_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count);
std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat]
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)");
m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
#endif
m.def("expert_count", &_expert_count, "FastMoE expert count (CUDA)");
m.def("local_scatter", &_local_scatter, "FastMoE local scatter (CUDA)");
m.def("local_gather", &_local_gather, "FastMoE local gather (CUDA)");
m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
}
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef FMOE_USE_NCCL
#include <nccl.h>
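// Fused global forward: for each local expert i, (1) exchange rows with every
// worker inside one NCCL group (send the rows this worker routed to expert i
// of worker j, receive the rows other workers routed to our expert i), (2) run
// a single GEMM with expert i's weight over everything received, and (3) send
// the results back to their source workers while receiving results for our own
// rows. Each expert uses its own stream; sync(num_expert) waits for all of them.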
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
long in_feat, long out_feat,
long num_expert, long world_size,
CudaStreamManager* smgr) {
int ptr = 0;
int send_ptr = 0;
int recv_ptr = 0;
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
scalar_t alpha = 1, beta = 0;
for (int i = 0; i < num_expert; ++i) {
int expert_count = 0;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
global_input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
recv_ptr += global_expert_count[idx];
expert_count += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count, in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
global_input_buf + ptr * in_feat, in_feat,
&beta,
global_output_buf + out_feat * ptr, out_feat
));
ptr += expert_count;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
global_output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
auto smgr = getCudaStreamManager(input_buf.device().index());
auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_fused_forward", ([&] {
moe_cuda_global_fused_forward_impl(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
in_feat, out_feat, num_expert, n_workers,
smgr);
}));
return {output_buf, global_input_buf};
}
#endif
#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
#ifdef FMOE_USE_NCCL
#include <nccl.h>
std::vector<torch::Tensor> _expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto smgr = getCudaStreamManager(local_expert_count.device().index());
fmoe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
num_expert, n_workers,
smgr);
return {global_expert_count};
}
std::vector<torch::Tensor> _global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
CHECK_INPUT(input_buf);
auto num_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_global_scatter", ([&] {
fmoe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, num_expert, n_workers,
smgr
);
}));
return {global_input_buf,};
}
std::vector<torch::Tensor> _global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
CHECK_INPUT(output_buf);
auto num_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
"fmoe_cuda_global_gather", ([&] {
fmoe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, num_expert, n_workers,
smgr
);
}));
return {local_output_buf,};
}
#include <c10d/ProcessGroupNCCL.hpp>
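// FastMoE needs its own NCCL communicator for its send/recv based exchanges.
// This subclass exists only so that _ensure_nccl can cast the ProcessGroupNCCL
// handed over from Python and reuse its rank/size and NCCL-ID broadcast
// machinery to bootstrap that communicator.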
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
ncclComm_t getcomm(at::Device dev) {
ncclUniqueId ncclID;
int rank = getRank();
if (rank == 0) {
ncclGetUniqueId(&ncclID);
}
broadcastUniqueNCCLID(&ncclID,
c10d::OpType::SEND,
"fastmoe_nccl_comm",
rank);
ncclComm_t comm;
ncclCommInitRank(&comm, getSize(), ncclID, rank);
return comm;
}
};
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
auto smgr = getCudaStreamManager(t.device().index());
if (smgr->ncclgood) {
return;
}
HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
smgr->ncclcomm = h->getcomm(t.device());
if (smgr->ncclcomm != 0) {
smgr->ncclgood = 1;
} else {
std::cerr << "Nccl initialization failed\n";
}
}
#endif // FMOE_USE_NCCL
#include "stream_manager.h"
#ifdef FMOE_USE_NCCL
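// All-to-all exchange of per-expert token counts: worker j learns, for each of
// its experts, how many rows every other worker will send it. Pairwise
// ncclSend/ncclRecv inside a single NCCL group acts as the all-to-all.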
void fmoe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int num_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}
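// Global scatter: push each locally sorted block of rows to the worker that
// hosts its expert. expert_ptr is an exclusive prefix sum over
// local_expert_count and gives the starting row of every block in the local
// buffer; received rows are packed densely in (local expert, source worker)
// order.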
template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
long*expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
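// Global gather: the inverse of global scatter. Results computed here for
// remote rows are sent back to their source workers, and results for rows this
// worker scattered out are received into their original offsets (expert_ptr).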
template<typename scalar_t>
void fmoe_cuda_global_gather_impl(
const scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
long send_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
#endif // FMOE_USE_NCCL
#include "local_exchange.cuh"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
std::vector<torch::Tensor> _expert_count(
torch::Tensor gate,
size_t num_expert) {
const auto batch_size = gate.size(0);
auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options);
auto pos_options = torch::TensorOptions()
.device(gate.device())
.dtype(torch::kInt32);
auto pos = torch::empty(batch_size, pos_options);
fmoe_cuda_expert_count_impl(
gate.data_ptr<int>(),
expert_count.data_ptr<int>(),
pos.data_ptr<int>(),
num_expert,
batch_size);
return {expert_count, pos};
}
std::vector<torch::Tensor> _local_scatter(
torch::Tensor input,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(input.device().index());
const auto batch_size = pos.size(0);
const auto in_feat = input.size(1);
auto opt = torch::TensorOptions()
.dtype(input.dtype())
.device(input.device());
auto input_buf = torch::empty({batch_size, in_feat}, opt);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fmoe_local_scatter",
([&] {
fmoe_cuda_local_scatter_impl<scalar_t>(
input.data_ptr<scalar_t>(),
pos.data_ptr<long>(),
input_buf.data_ptr<scalar_t>(),
batch_size,
in_feat,
smgr);
}));
return {input_buf,};
}
std::vector<torch::Tensor> _local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(output_buf.device().index());
const auto batch_size = pos.size(0);
const auto out_feat = output_buf.size(1);
auto opt = torch::TensorOptions()
.dtype(output_buf.dtype())
.device(output_buf.device());
auto output = torch::empty({batch_size, out_feat}, opt);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "fmoe_local_gather",
([&] {
fmoe_cuda_local_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
pos.data_ptr<long>(),
output.data_ptr<scalar_t>(),
batch_size,
out_feat,
smgr);
}));
return {output,};
}
#include "stream_manager.h"
#include "utils/helper_cuda.h"
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
const long* offset, const scalar_t** ptrs) {
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < n) {
ptrs[idx] = base + stride * offset[idx];
}
}
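// One thread block per output row: copies row pos[blockIdx.x] of inbuf into
// row blockIdx.x of oubuf, wid elements per row. batch_gather_kernel below
// performs the inverse permutation.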
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
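// Host-side counting: copy the gate (expert id per token) to the CPU,
// histogram tokens per expert, build an exclusive prefix sum, and assign each
// token the next free slot of its expert. The resulting pos array (copied back
// to the device) places tokens of the same expert contiguously.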
void fmoe_cuda_expert_count_impl(
const int* d_gate,
int* expert_count,
int* d_pos,
const size_t num_expert,
const size_t batch_size) {
int *gate = new int[batch_size];
int *expert_ptr = new int[num_expert];
memset(expert_count, 0, sizeof(int) * num_expert);
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost));
for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]];
}
expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
int *pos = new int[batch_size];
for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++;
}
for (int i = num_expert - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1];
}
expert_ptr[0] = 0;
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));
delete [] gate;
delete [] expert_ptr;
delete [] pos;
}
template <typename scalar_t>
void fmoe_cuda_local_scatter_impl(
const scalar_t* input,
const long* d_pos,
scalar_t* input_buf,
const long batch_size,
const long in_feat,
CudaStreamManager* smgr) {
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
input_buf);
smgr->sync(1);
}
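// Inverse of batch_scatter_kernel: copies row blockIdx.x of inbuf back into
// row pos[blockIdx.x] of oubuf, restoring the original token order.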
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
template <typename scalar_t>
void fmoe_cuda_local_gather_impl(
const scalar_t* output_buf,
const long* d_pos,
scalar_t* output,
const size_t batch_size,
const size_t out_feat,
CudaStreamManager* smgr) {
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
output);
smgr->sync(1);
}
#include <torch/extension.h>
#include <cstdio>
#include <iostream>
#include <vector>
#include "moe_cuda_kernel.h"
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
torch::Tensor gate,
size_t num_expert) {
CHECK_INPUT(gate);
return moe_cuda_expert_count(gate, num_expert);
}
std::vector<torch::Tensor> moe_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
CHECK_INPUT(input);
return moe_cuda_local_scatter(input, pos);
}
std::vector<torch::Tensor> moe_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
CHECK_INPUT(output_buf);
return moe_cuda_local_gather(output_buf, pos);
}
std::vector<torch::Tensor> moe_forward(
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count // [batch_size]
) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight. Note the following fact:
W x + b = [W b] [x; 1]
*/
return moe_cuda_forward(input_buf, weight, expert_count);
}
std::vector<torch::Tensor> moe_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat]
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
CHECK_INPUT(grad_output_buf);
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight. Note the following fact:
W x + b = [W b] [x; 1]
*/
return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
}
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_expert_exchange(
torch::Tensor local_expert_count,
size_t num_expert, size_t n_workers) {
return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
}
std::vector<torch::Tensor> moe_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
size_t batch_size, size_t n_workers) {
CHECK_INPUT(input_buf);
return moe_cuda_global_scatter(input_buf,
local_expert_count, global_expert_count,
batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
size_t batch_size, size_t n_workers) {
CHECK_INPUT(output_buf);
return moe_cuda_global_gather(output_buf,
local_expert_count, global_expert_count,
batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
return moe_cuda_global_fused_forward(
input_buf, weight, local_expert_count, global_expert_count,
global_batch_size, local_batch_size, n_workers);
}
#include <c10d/ProcessGroupNCCL.hpp>
#include "cuda_stream_manager.h"
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
ncclComm_t getcomm(at::Device dev) {
auto key = std::to_string(dev.index());
#ifdef ENABLE_NCCL_P2P_SUPPORT
ncclUniqueId ncclID;
int rank = getRank();
if (rank == 0) {
ncclGetUniqueId(&ncclID);
}
broadcastUniqueNCCLID(&ncclID,
c10d::OpType::SEND,
"fastmoe_nccl_comm",
rank);
ncclComm_t comm;
ncclCommInitRank(&comm, getSize(), ncclID, rank);
return comm;
#else
auto v = getNCCLComm(key, {dev});
if (v.size() == 0) {
std::cerr << "PyTorch has nothing\n";
return 0;
}
int count;
ncclCommCount(v[0]->getNcclComm(), &count);
std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size " << count << "\n";
return v[0]->getNcclComm();
#endif
}
};
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
auto smgr = getCudaStreamManager(t.device().index());
if (smgr->ncclgood) {
return;
}
HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
smgr->ncclcomm = h->getcomm(t.device());
if (smgr->ncclcomm != 0) {
smgr->ncclgood = 1;
} else {
std::cerr << "Nccl initialization failed\n";
}
}
#endif // MOE_USE_NCCL
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
#ifdef MOE_USE_NCCL
m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)");
m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)");
}
\ No newline at end of file
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <nccl.h>
void moe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int num_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto smgr = getCudaStreamManager(local_expert_count.device().index());
moe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
num_expert, n_workers,
smgr);
return {global_expert_count};
}
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
long*expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
auto num_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_scatter", ([&] {
moe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, num_expert, n_workers,
smgr
);
}));
return {global_input_buf,};
}
template<typename scalar_t>
void moe_cuda_global_gather_impl(
const scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
long send_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
auto num_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
"moe_cuda_global_gather", ([&] {
moe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, num_expert, n_workers,
smgr
);
}));
return {local_output_buf,};
}
#endif
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "timer.hh"
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
const long* offset, const scalar_t** ptrs) {
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < n) {
ptrs[idx] = base + stride * offset[idx];
}
}
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
void moe_cuda_expert_count_impl(
const int* d_gate,
int* expert_count,
int* d_pos,
const size_t num_expert,
const size_t batch_size) {
int *gate = new int[batch_size];
int *expert_ptr = new int[num_expert];
memset(expert_count, 0, sizeof(int) * num_expert);
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost));
for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]];
}
expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
int *pos = new int[batch_size];
for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++;
}
for (int i = num_expert - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1];
}
expert_ptr[0] = 0;
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));
delete [] gate;
delete [] expert_ptr;
delete [] pos;
}
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
const scalar_t* input,
const long* d_pos,
scalar_t* input_buf,
const long batch_size,
const long in_feat,
CudaStreamManager* smgr) {
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
input_buf);
smgr->sync(1);
}
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
template <typename scalar_t>
void moe_cuda_local_gather_impl(
const scalar_t* output_buf,
const long* d_pos,
scalar_t* output,
const size_t batch_size,
const size_t out_feat,
CudaStreamManager* smgr) {
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
output);
smgr->sync(1);
}
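// Per-expert forward GEMM: rows are already sorted by expert, so each expert's
// contiguous slice is multiplied by its own weight in one cublasXgemm call on
// its own stream; experts with no rows are skipped.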
template <typename scalar_t>
void moe_cuda_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
scalar_t* output_buf,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
input_buf + ptr * in_feat, in_feat,
&beta,
output_buf + out_feat * ptr, out_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
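// Per-expert backward: two GEMMs per expert compute grad_input = grad_output
// * W and grad_weight = grad_output^T * input over that expert's slice of
// rows. Experts that received no rows get their grad_weight slice zeroed.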
template <typename scalar_t>
void moe_cuda_backward_impl(
const scalar_t* grad_output_buf,
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
scalar_t* grad_input_buf,
scalar_t* grad_weight,
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
cudaMemset(grad_weight + i * in_feat * out_feat, 0,
sizeof(scalar_t) * in_feat * out_feat);
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
// Backward input: g_i = w @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_N,
in_feat, expert_count[i], out_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_input_buf + in_feat * ptr, in_feat
));
// Backward weight: g_w = i @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_T,
in_feat, out_feat, expert_count[i],
&alpha,
input_buf + in_feat * ptr, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_weight + i * in_feat * out_feat, in_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate,
size_t num_expert) {
const auto batch_size = gate.size(0);
auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options);
auto pos_options = torch::TensorOptions()
.device(gate.device())
.dtype(torch::kInt32);
auto pos = torch::empty(batch_size, pos_options);
moe_cuda_expert_count_impl(
gate.data_ptr<int>(),
expert_count.data_ptr<int>(),
pos.data_ptr<int>(),
num_expert,
batch_size);
return {expert_count, pos};
}
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(input.device().index());
const auto batch_size = pos.size(0);
const auto in_feat = input.size(1);
auto opt = torch::TensorOptions()
.dtype(input.dtype())
.device(input.device());
auto input_buf = torch::empty({batch_size, in_feat}, opt);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
([&] {
moe_cuda_local_scatter_impl<scalar_t>(
input.data_ptr<scalar_t>(),
pos.data_ptr<long>(),
input_buf.data_ptr<scalar_t>(),
batch_size,
in_feat,
smgr);
}));
return {input_buf,};
}
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(output_buf.device().index());
const auto batch_size = pos.size(0);
const auto out_feat = output_buf.size(1);
auto opt = torch::TensorOptions()
.dtype(output_buf.dtype())
.device(output_buf.device());
auto output = torch::empty({batch_size, out_feat}, opt);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
([&] {
moe_cuda_local_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
pos.data_ptr<long>(),
output.data_ptr<scalar_t>(),
batch_size,
out_feat,
smgr);
}));
return {output,};
}
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
auto output = torch::empty({batch_size, out_feat}, out_options);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
([&] {
moe_cuda_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
output.data_ptr<scalar_t>(),
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {output, };
}
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat]
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
moe_cuda_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
batch_size,
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {grad_input_buf, grad_weight};
}
#ifndef MOE_CUDA_KERNEL_H
#define MOE_CUDA_KERNEL_H
#include <vector>
#include <torch/extension.h>
#include <torch/torch.h>
#include "helper_cuda.h"
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate, size_t num_expert);
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count);
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count);
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_cuda_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers);
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers);
#endif
#endif // MOE_CUDA_KERNEL_H
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef MOE_USE_NCCL
#include <nccl.h>
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
long in_feat, long out_feat,
long num_expert, long world_size,
CudaStreamManager* smgr) {
int ptr = 0;
int send_ptr = 0;
int recv_ptr = 0;
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
scalar_t alpha = 1, beta = 0;
for (int i = 0; i < num_expert; ++i) {
int expert_count = 0;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
global_input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
recv_ptr += global_expert_count[idx];
expert_count += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count, in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
global_input_buf + ptr * in_feat, in_feat,
&beta,
global_output_buf + out_feat * ptr, out_feat
));
ptr += expert_count;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
global_output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
auto smgr = getCudaStreamManager(input_buf.device().index());
auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_fused_forward", ([&] {
moe_cuda_global_fused_forward_impl(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
in_feat, out_feat, num_expert, n_workers,
smgr);
}));
return {output_buf, global_input_buf};
}
#endif
#include "parallel_linear.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
std::vector<torch::Tensor> _linear_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count
) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef FMOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
auto output = torch::empty({batch_size, out_feat}, out_options);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "fmoe_linear_forward",
([&] {
fmoe_cuda_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
output.data_ptr<scalar_t>(),
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {output, };
}
std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat]
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
CHECK_INPUT(grad_output_buf);
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef FMOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "fmoe_linear_backward", ([&] {
fmoe_cuda_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
batch_size,
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {grad_input_buf, grad_weight};
}
#include "stream_manager.h"
#include "utils/cublas_wrapper.h"
template <typename scalar_t>
void fmoe_cuda_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
scalar_t* output_buf,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
input_buf + ptr * in_feat, in_feat,
&beta,
output_buf + out_feat * ptr, out_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
template <typename scalar_t>
void fmoe_cuda_backward_impl(
const scalar_t* grad_output_buf,
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
scalar_t* grad_input_buf,
scalar_t* grad_weight,
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
cudaMemset(grad_weight + i * in_feat * out_feat, 0,
sizeof(scalar_t) * in_feat * out_feat);
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
// Backward input: g_i = w @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_N,
in_feat, expert_count[i], out_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_input_buf + in_feat * ptr, in_feat
));
// Backward weight: g_w = i @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_N,
CUBLAS_OP_T,
in_feat, out_feat, expert_count[i],
&alpha,
input_buf + in_feat * ptr, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_weight + i * in_feat * out_feat, in_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#include "stream_manager.h"
#define SMGR_N_STREAMS 16
cudaStream_t CudaStreamManager::stream(size_t idx) {
return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::sync(int idx) {
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
}
void CudaStreamManager::setup(const int device) {
#ifdef FMOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
}
void CudaStreamManager::destroy() {
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
delete[] streams;
delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;
CudaStreamManager* getCudaStreamManager(const int device) {
auto it = smgrs.find(device);
if (it == smgrs.end()) {
smgr_mtx.lock();
it = smgrs.find(device);
if (it == smgrs.end()) {
auto smgr = new CudaStreamManager(device);
smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
smgr_mtx.unlock();
return smgr;
} else {
smgr_mtx.unlock();
}
}
return it->second;
}
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H
#include "helper_cuda.h"
#include "utils/helper_cuda.h"
#ifdef MOE_USE_NCCL
#ifdef FMOE_USE_NCCL
#include <nccl.h>
#define NCCL_SAFE_CALL(__fn__) { \
auto __res__ = __fn__; \
if (__res__ != ncclSuccess) { \
fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
exit(-1); \
} \
}
#endif
......@@ -21,25 +21,25 @@ public:
int device;
cublasHandle_t* handles;
cudaStream_t* streams;
#ifdef FMOE_USE_NCCL
    char ncclgood;
    ncclComm_t ncclcomm;
#endif
public:
    CudaStreamManager(int device_): device(device_) {
        this->setup(device);
    }
    void setup(int);
    void sync(int=0);
    void destroy();
    cudaStream_t stream(size_t=0);
    cublasHandle_t handle(size_t=0);
    ~CudaStreamManager() {
        this->destroy();
    }
};
......
......@@ -85,11 +85,11 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
const c10::Half *beta,
c10::Half *C, int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k,
        (const __half*)alpha,
        (const __half*)A, lda,
        (const __half*)B, ldb,
        (const __half*)beta,
        (__half*)C, ldc);
}
#endif // CUBLAS_WRAPPER_H
#ifndef FMOE_UTILS_H
#define FMOE_UTILS_H
#define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#endif // FMOE_UTILS_H
......@@ -31,6 +31,7 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <stdio.h>
#ifndef HELPER_CUDA_H
#define HELPER_CUDA_H
......
......@@ -4,8 +4,8 @@
#include <chrono>
inline double getDuration(std::chrono::time_point<std::chrono::system_clock> a,
        std::chrono::time_point<std::chrono::system_clock> b) {
    return std::chrono::duration<double>(b - a).count();
}
#define timestamp(__var__) auto __var__ = std::chrono::system_clock::now();
......