Unverified commit baae8fb9, authored by Rick Ho, committed by GitHub

Merge pull request #31 from laekov/gate

Reconstruct gate and add gshard / switch
parents 3c42c892 8d14dd29
#include "balancing.cuh"
#include <torch/extension.h>
/*
* Note: due to limitations of the CUDA atomic operations, capacity must be int32.
*/
std::vector<torch::Tensor> _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_worker) {
CHECK_INPUT(expert_count);
CHECK_INPUT(capacity);
auto expert_count_ack = torch::empty_like(expert_count);
auto smgr = getCudaStreamManager(expert_count.device().index());
fmoe_cuda_limit_by_capacity_impl(
expert_count.data_ptr<long>(),
capacity.data_ptr<int>(),
expert_count_ack.data_ptr<long>(),
n_expert, n_worker, smgr);
return {expert_count_ack};
}
void _prune_gate_by_capacity(
torch::Tensor gate_idx, torch::Tensor expert_count,
long n_expert, long n_worker) {
auto smgr = getCudaStreamManager(expert_count.device().index());
auto batch_size = gate_idx.numel();
fmoe_cuda_prune_gate_by_capacity_impl(
gate_idx.data_ptr<long>(),
expert_count.data_ptr<int>(),
batch_size, n_expert, n_worker, smgr);
}
#include "stream_manager.h"
#include "utils/fmoe_utils.h"
#include <cuda.h>
__global__
void limit_by_capacity_kernel(const long* ec, int* cap, long* eca,
const long n_expert, const long n_worker) {
int eid = blockIdx.y;
int wid = blockIdx.x * blockDim.x + threadIdx.x;
if (wid < n_worker) {
int proposal = ec[wid * n_expert + eid];
int cap_left = atomicSub(cap + eid, proposal);
if (cap_left >= proposal) {
eca[wid * n_expert + eid] = proposal;
} else if (cap_left >= 0) {
eca[wid * n_expert + eid] = cap_left;
} else {
eca[wid * n_expert + eid] = 0;
}
}
}
void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
long* eca, const long n_expert, const long n_worker,
CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
dim3 block_dim(1024);
limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
ec, cap, eca, n_expert, n_worker);
smgr->sync(1);
}
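/*
 * Illustrative example (added, not part of the original source): atomicSub
 * returns the capacity value *before* the subtraction. With one expert whose
 * capacity is 5 and three workers proposing 3, 3 and 4 tokens, one possible
 * ordering is: worker 0 observes 5 >= 3 and keeps 3, worker 1 observes 2 and
 * keeps only 2, worker 2 observes -1 and keeps 0. The accepted counts written
 * to eca therefore never exceed the original capacity.
 */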
__global__
void prune_gate_by_capacity_kernel(long* gate_idx, int* ec,
const long batch_size, const long n_expert, const long n_worker) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < batch_size) {
int orig_cap = atomicSub(ec + gate_idx[i], 1);
if (orig_cap <= 0) {
gate_idx[i] = -1;
}
}
}
void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, int* ec,
const long batch_size, const long n_expert, const long n_worker,
CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(batch_size, 1024));
dim3 block_dim(1024);
prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
gate_idx, ec, batch_size, n_expert, n_worker
);
smgr->sync(1);
}
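/*
 * Illustrative example (added, not part of the original source): ec holds the
 * remaining per-expert capacity as int32. If expert 0 has capacity 2 and three
 * tokens are routed to it, the first two atomicSub calls observe 2 and 1 and
 * keep their gate index, while the third observes 0 and is marked with
 * gate_idx = -1, i.e. the token is dropped.
 */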
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#include "cuda_stream_manager.h"
#define SMGR_N_STREAMS 16
cudaStream_t CudaStreamManager::stream(size_t idx) {
return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::sync(int idx) {
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
}
void CudaStreamManager::setup(const int device) {
#ifdef MOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
}
void CudaStreamManager::destroy() {
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
delete[] streams;
delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;
CudaStreamManager* getCudaStreamManager(const int device) {
auto it = smgrs.find(device);
if (it == smgrs.end()) {
smgr_mtx.lock();
it = smgrs.find(device);
if (it == smgrs.end()) {
auto smgr = new CudaStreamManager(device);
smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
smgr_mtx.unlock();
return smgr;
} else {
smgr_mtx.unlock();
}
}
return it->second;
}
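/*
 * Minimal usage sketch (hypothetical, not part of the original sources):
 * round-robin asynchronous work across the per-device streams and then wait
 * for them. The function and buffer names are illustrative, and it assumes
 * the CUDA runtime declarations are visible through cuda_stream_manager.h.
 */
static void demo_stream_manager(int device, float* dev_buf, size_t n) {
    auto smgr = getCudaStreamManager(device);
    size_t chunk = (n + SMGR_N_STREAMS - 1) / SMGR_N_STREAMS;
    for (int i = 0; i < SMGR_N_STREAMS && i * chunk < n; ++i) {
        size_t len = n - i * chunk;
        if (len > chunk) len = chunk;
        // stream(i) wraps modulo SMGR_N_STREAMS internally
        checkCudaErrors(cudaMemsetAsync(dev_buf + i * chunk, 0,
                len * sizeof(float), smgr->stream(i)));
    }
    smgr->sync(SMGR_N_STREAMS);  // wait for every stream touched above
}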
#include <iostream>
#include <vector>
#include <torch/extension.h>
// global_exchange
#ifdef FMOE_USE_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
std::vector<torch::Tensor> _expert_exchange(
torch::Tensor local_expert_count,
long n_expert, long n_workers);
std::vector<torch::Tensor> _global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> _global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // FMOE_USE_NCCL
// local_exchange
void _assign_pos(
torch::Tensor cum_count,
torch::Tensor gate,
torch::Tensor pos);
// parallel_linear
std::vector<torch::Tensor> _linear_forward(
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
);
std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
);
// balancing
std::vector<torch::Tensor> _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_worker);
void _prune_gate_by_capacity(
torch::Tensor gate_idx, torch::Tensor expert_count,
long n_expert, long n_worker);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)");
m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
#endif
m.def("assign_pos_", &_assign_pos, "FastMoE assign pos by gate(CUDA)");
m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
m.def("limit_by_capacity", &_limit_by_capacity, "FastMoE limit experts by capacity(CUDA)");
m.def("prune_gate_by_capacity", &_prune_gate_by_capacity, "FastMoE prune gate by capacity(CUDA)");
}
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef FMOE_USE_NCCL
#include <nccl.h>
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
long in_feat, long out_feat,
long num_expert, long world_size,
CudaStreamManager* smgr) {
int ptr = 0;
int send_ptr = 0;
int recv_ptr = 0;
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
scalar_t alpha = 1, beta = 0;
for (int i = 0; i < num_expert; ++i) {
int expert_count = 0;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
global_input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
recv_ptr += global_expert_count[idx];
expert_count += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count, in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
global_input_buf + ptr * in_feat, in_feat,
&beta,
global_output_buf + out_feat * ptr, out_feat
));
ptr += expert_count;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
global_output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(num_expert);
}
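/*
 * Note (added for clarity, not in the original source): for each local expert
 * i the loop above performs an all-to-all exchange of that expert's input
 * rows, runs one cuBLAS GEMM with the expert's [out_feat x in_feat] weight on
 * stream/handle i, and then exchanges the results back, so communication for
 * one expert can overlap with computation for the others on their own streams.
 */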
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
auto smgr = getCudaStreamManager(input_buf.device().index());
auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_fused_forward", ([&] {
moe_cuda_global_fused_forward_impl(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
in_feat, out_feat, num_expert, n_workers,
smgr);
}));
return {output_buf, global_input_buf};
}
#endif
#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
#ifdef FMOE_USE_NCCL
#include <nccl.h>
std::vector<torch::Tensor> _expert_exchange(
torch::Tensor local_expert_count,
long n_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto smgr = getCudaStreamManager(local_expert_count.device().index());
fmoe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
n_expert, n_workers,
smgr);
return {global_expert_count};
}
std::vector<torch::Tensor> _global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
CHECK_INPUT(input_buf);
auto n_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_global_scatter", ([&] {
fmoe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, n_expert, n_workers,
smgr
);
}));
return {global_input_buf,};
}
std::vector<torch::Tensor> _global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
CHECK_INPUT(output_buf);
auto n_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
"fmoe_cuda_global_gather", ([&] {
fmoe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, n_expert, n_workers,
smgr
);
}));
return {local_output_buf,};
}
#include <c10d/ProcessGroupNCCL.hpp>
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
ncclComm_t getcomm(at::Device dev) {
ncclUniqueId ncclID;
int rank = getRank();
if (rank == 0) {
ncclGetUniqueId(&ncclID);
}
broadcastUniqueNCCLID(&ncclID,
c10d::OpType::SEND,
"fastmoe_nccl_comm",
rank);
ncclComm_t comm;
ncclCommInitRank(&comm, getSize(), ncclID, rank);
return comm;
}
};
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
auto smgr = getCudaStreamManager(t.device().index());
if (smgr->ncclgood) {
return;
}
HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
smgr->ncclcomm = h->getcomm(t.device());
if (smgr->ncclcomm != 0) {
smgr->ncclgood = 1;
} else {
std::cerr << "Nccl initialization failed\n";
}
}
#endif // FMOE_USE_NCCL
#include "stream_manager.h"
#ifdef FMOE_USE_NCCL
void fmoe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int n_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}
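/*
 * Illustrative layout note (added, not part of the original source): this is
 * an all-to-all of per-expert token counts. With world_size = 2 and
 * n_expert = 2, this rank sends local_expert_count[0..1] (its tokens routed
 * to rank 0's experts) to rank 0 and [2..3] to rank 1, and receives into
 * global_expert_count[j * n_expert .. ] the counts that each rank j routed to
 * this rank's own experts.
 */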
template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t n_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[n_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < n_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < n_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * n_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
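/*
 * Note (added for clarity, not in the original source): for each expert index
 * i, the sender pushes the locally gathered rows destined for expert i of
 * every rank j (row offsets come from the prefix sums in expert_ptr), while
 * the receiver packs the rows that each rank routed to its own expert i
 * contiguously, ordered by (local expert, source rank), so each local expert
 * later sees its input rows grouped together.
 */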
template<typename scalar_t>
void fmoe_cuda_global_gather_impl(
const scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t n_expert, size_t world_size,
CudaStreamManager* smgr) {
long send_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[n_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < n_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < n_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * n_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
#endif // FMOE_USE_NCCL
#include "local_exchange.cuh"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
void _assign_pos(
torch::Tensor cum_count,
torch::Tensor gate,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(cum_count.device().index());
auto gate_shp = gate.sizes();
size_t batch_size = gate_shp[0], topk = 1;
if (gate_shp.size() == 2) {
topk = gate_shp[1];
}
fmoe_cuda_assign_pos_impl(
cum_count.data_ptr<int>(),
gate.data_ptr<long>(),
pos.data_ptr<long>(),
batch_size, topk, smgr);
}
#include "stream_manager.h"
#include "utils/helper_cuda.h"
#include "utils/fmoe_utils.h"
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
size_t numel, size_t topk) {
size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < numel) {
long gate_idx = gate[idx];
if (gate_idx > -1) {
int p = atomicSub(cum_count + gate_idx, 1);
pos[p - 1] = (long)idx;
}
}
}
void fmoe_cuda_assign_pos_impl(
int* cum_count, const long* gate, long* pos,
const size_t batch_size, const size_t topk,
CudaStreamManager* smgr) {
size_t numel = batch_size * topk;
assign_pos_kernel
<<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
(cum_count, gate, pos, numel, topk);
smgr->sync(1);
}
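/*
 * Worked example (illustrative, not from the original source): with
 * gate = [1, 0, 1, -1] and cum_count = [1, 3] (inclusive prefix sums of the
 * per-expert counts), each token atomically decrements its expert's counter
 * and writes its own index into the slot it claimed, so pos becomes
 * [1 | 0, 2]: expert 0's token at pos[0], expert 1's tokens at pos[1..2]
 * (order within an expert is nondeterministic). Tokens with gate = -1 are
 * skipped.
 */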
#include <torch/extension.h>
#include <cstdio>
#include <iostream>
#include <vector>
#include "moe_cuda_kernel.h"
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
torch::Tensor gate,
size_t num_expert) {
CHECK_INPUT(gate);
return moe_cuda_expert_count(gate, num_expert);
}
std::vector<torch::Tensor> moe_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
CHECK_INPUT(input);
return moe_cuda_local_scatter(input, pos);
}
std::vector<torch::Tensor> moe_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
CHECK_INPUT(output_buf);
return moe_cuda_local_gather(output_buf, pos);
}
std::vector<torch::Tensor> moe_forward(
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor expert_count, // [num_expert]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
at::optional<torch::Tensor> bias_o // [num_expert x out_feat] or None
) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
// check if bias is valid in case it exists
if (bias_o.has_value()) {
auto bias = bias_o.value();
CHECK_INPUT(bias);
}
return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
}
std::vector<torch::Tensor> moe_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat]
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor expert_count, // [num_expert]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
at::optional<torch::Tensor> bias_o // [num_expert x out_feat] or None
) {
CHECK_INPUT(grad_output_buf);
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
// check if bias is valid in case it exists
if (bias_o.has_value()) {
auto bias = bias_o.value();
CHECK_INPUT(bias);
}
return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o);
}
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_expert_exchange(
torch::Tensor local_expert_count,
size_t num_expert, size_t n_workers) {
return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
}
std::vector<torch::Tensor> moe_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
size_t batch_size, size_t n_workers) {
CHECK_INPUT(input_buf);
return moe_cuda_global_scatter(input_buf,
local_expert_count, global_expert_count,
batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
size_t batch_size, size_t n_workers) {
CHECK_INPUT(output_buf);
return moe_cuda_global_gather(output_buf,
local_expert_count, global_expert_count,
batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
return moe_cuda_global_fused_forward(
input_buf, weight, local_expert_count, global_expert_count,
global_batch_size, local_batch_size, n_workers);
}
#include <c10d/ProcessGroupNCCL.hpp>
#include "cuda_stream_manager.h"
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
ncclComm_t getcomm(at::Device dev) {
auto key = std::to_string(dev.index());
#ifdef ENABLE_NCCL_P2P_SUPPORT
ncclUniqueId ncclID;
int rank = getRank();
if (rank == 0) {
ncclGetUniqueId(&ncclID);
}
broadcastUniqueNCCLID(&ncclID,
c10d::OpType::SEND,
"fastmoe_nccl_comm",
rank);
ncclComm_t comm;
ncclCommInitRank(&comm, getSize(), ncclID, rank);
return comm;
#else
auto v = getNCCLComm(key, {dev});
if (v.size() == 0) {
std::cerr << "PyTorch has nothing\n";
return 0;
}
int count;
ncclCommCount(v[0]->getNcclComm(), &count);
std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size " << count << "\n";
return v[0]->getNcclComm();
#endif
}
};
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
auto smgr = getCudaStreamManager(t.device().index());
if (smgr->ncclgood) {
return;
}
HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
smgr->ncclcomm = h->getcomm(t.device());
if (smgr->ncclcomm != 0) {
smgr->ncclgood = 1;
} else {
std::cerr << "Nccl initialization failed\n";
}
}
#endif // MOE_USE_NCCL
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
#ifdef MOE_USE_NCCL
m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)");
m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
m.def("forward", &moe_forward, "MoE forward (CUDA)");
m.def("backward", &moe_backward, "MoE backward (CUDA)");
}
\ No newline at end of file
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <nccl.h>
void moe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int num_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + num_expert * i,
num_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto smgr = getCudaStreamManager(local_expert_count.device().index());
moe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
num_expert, n_workers,
smgr);
return {global_expert_count};
}
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
auto num_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_scatter", ([&] {
moe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, num_expert, n_workers,
smgr
);
}));
return {global_input_buf,};
}
template<typename scalar_t>
void moe_cuda_global_gather_impl(
const scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t num_expert, size_t world_size,
CudaStreamManager* smgr) {
long send_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}
std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
auto num_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
"moe_cuda_global_gather", ([&] {
moe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, num_expert, n_workers,
smgr
);
}));
return {local_output_buf,};
}
#endif
#ifndef MOE_CUDA_KERNEL_H
#define MOE_CUDA_KERNEL_H
#include <vector>
#include <torch/extension.h>
#include <torch/torch.h>
#include "helper_cuda.h"
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate, size_t num_expert);
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos);
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias);
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias);
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_cuda_global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> moe_cuda_global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
std::vector<torch::Tensor> moe_cuda_expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers);
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers);
#endif
#endif // MOE_CUDA_KERNEL_H
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef MOE_USE_NCCL
#include <nccl.h>
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
long in_feat, long out_feat,
long num_expert, long world_size,
CudaStreamManager* smgr) {
int ptr = 0;
int send_ptr = 0;
int recv_ptr = 0;
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
scalar_t alpha = 1, beta = 0;
for (int i = 0; i < num_expert; ++i) {
int expert_count = 0;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
global_input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
recv_ptr += global_expert_count[idx];
expert_count += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count, in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
global_input_buf + ptr * in_feat, in_feat,
&beta,
global_output_buf + out_feat * ptr, out_feat
));
ptr += expert_count;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
global_output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
auto smgr = getCudaStreamManager(input_buf.device().index());
auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_fused_forward", ([&] {
moe_cuda_global_fused_forward_impl(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
in_feat, out_feat, num_expert, n_workers,
smgr);
}));
return {output_buf, global_input_buf};
}
#endif
#include "parallel_linear.cuh"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
std::vector<torch::Tensor> _linear_forward(
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif
torch::Tensor output;
if (bias.has_value()) {
output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
} else{
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
output = torch::empty({batch_size, out_feat}, out_options);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
([&] {
fmoe_cuda_linear_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
output.data_ptr<scalar_t>(),
bias.has_value(),
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {output, };
}
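/*
 * Note (added for clarity, not in the original source; the GEMM body is
 * collapsed further below): weight is laid out as
 * [num_expert, out_feat, in_feat]. When a bias is given, the output buffer is
 * pre-filled with each expert's bias row repeated expert_count[e] times, so
 * the per-expert GEMMs can accumulate on top of the bias; otherwise an
 * uninitialized [batch_size, out_feat] buffer is allocated.
 */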
std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
fmoe_cuda_linear_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
grad_bias.data_ptr<scalar_t>(),
bias.has_value(),
batch_size,
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {grad_input_buf, grad_weight, grad_bias};
}
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "timer.hh"
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
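// Sanity check of the ceiling-division macro (added for illustration):
static_assert(CEIL(1000, 256) == 4, "rounds up for a non-multiple");
static_assert(CEIL(1024, 256) == 4, "exact multiples are not over-counted");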
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
const long* offset, const scalar_t** ptrs) {
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < n) {
ptrs[idx] = base + stride * offset[idx];
}
}
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
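/*
 * Note (added for clarity, not in the original source): one thread block
 * copies one row; output row blockIdx.x receives input row pos[blockIdx.x],
 * with the block's threads striding over the wid feature elements of the row.
 * The matching batch_gather_kernel performs the inverse copy.
 */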
#include "stream_manager.h"
#include "utils/cublas_wrapper.h"
@@ -88,84 +52,8 @@ void column_reduce(const scalar_t * matrix, scalar_t * result,
}
void moe_cuda_expert_count_impl(
const int* d_gate,
int* expert_count,
int* d_pos,
const size_t num_expert,
const size_t batch_size) {
int *gate = new int[batch_size];
int *expert_ptr = new int[num_expert];
memset(expert_count, 0, sizeof(int) * num_expert);
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost));
for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]];
}
expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
int *pos = new int[batch_size];
for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++;
}
for (int i = num_expert - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1];
}
expert_ptr[0] = 0;
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));
delete [] gate;
delete [] expert_ptr;
}
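/*
 * Worked example (illustrative, not from the original source): with
 * num_expert = 3 and gate = [2, 0, 2, 1] copied back to the host, the counts
 * become expert_count = [1, 1, 2], the prefix sums give expert_ptr = [0, 1, 2],
 * and each token claims the next slot of its expert, yielding
 * pos = [2, 0, 3, 1]; i.e. token i is written to slot pos[i] of the
 * expert-sorted buffer.
 */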
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
const scalar_t* input,
const long* d_pos,
scalar_t* input_buf,
const long batch_size,
const long in_feat,
CudaStreamManager* smgr) {
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
input_buf);
smgr->sync(1);
}
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
}
template <typename scalar_t>
void moe_cuda_local_gather_impl(
const scalar_t* output_buf,
const long* d_pos,
scalar_t* output,
const size_t batch_size,
const size_t out_feat,
CudaStreamManager* smgr) {
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
output);
smgr->sync(1);
}
template <typename scalar_t>
void fmoe_cuda_linear_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
@@ -200,7 +88,7 @@ void moe_cuda_forward_impl(
}
template <typename scalar_t>
void fmoe_cuda_linear_backward_impl(
const scalar_t* grad_output_buf,
const scalar_t* input_buf,
const scalar_t* weight,
@@ -272,165 +160,3 @@ void moe_cuda_backward_impl(
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate,
size_t num_expert) {
const auto batch_size = gate.size(0);
auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options);
auto pos_options = torch::TensorOptions()
.device(gate.device())
.dtype(torch::kInt32);
auto pos = torch::empty(batch_size, pos_options);
moe_cuda_expert_count_impl(
gate.data_ptr<int>(),
expert_count.data_ptr<int>(),
pos.data_ptr<int>(),
num_expert,
batch_size);
return {expert_count, pos};
}
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(input.device().index());
const auto batch_size = pos.size(0);
const auto in_feat = input.size(1);
auto opt = torch::TensorOptions()
.dtype(input.dtype())
.device(input.device());
auto input_buf = torch::empty({batch_size, in_feat}, opt);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
([&] {
moe_cuda_local_scatter_impl<scalar_t>(
input.data_ptr<scalar_t>(),
pos.data_ptr<long>(),
input_buf.data_ptr<scalar_t>(),
batch_size,
in_feat,
smgr);
}));
return {input_buf,};
}
std::vector<torch::Tensor> moe_cuda_local_gather(
torch::Tensor output_buf,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(output_buf.device().index());
const auto batch_size = pos.size(0);
const auto out_feat = output_buf.size(1);
auto opt = torch::TensorOptions()
.dtype(output_buf.dtype())
.device(output_buf.device());
auto output = torch::empty({batch_size, out_feat}, opt);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
([&] {
moe_cuda_local_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
pos.data_ptr<long>(),
output.data_ptr<scalar_t>(),
batch_size,
out_feat,
smgr);
}));
return {output,};
}
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif
torch::Tensor output;
if (bias.has_value()) {
output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
} else{
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
output = torch::empty({batch_size, out_feat}, out_options);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
([&] {
moe_cuda_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
output.data_ptr<scalar_t>(),
bias.has_value(),
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {output, };
}
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat]
torch::Tensor input_buf, // [batch_size x out_feat]
torch::Tensor expert_count,
torch::Tensor weight, // [num_expert x out_feat x in_feat]
at::optional<torch::Tensor> bias
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
moe_cuda_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
grad_bias.data_ptr<scalar_t>(),
bias.has_value(),
batch_size,
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {grad_input_buf, grad_weight, grad_bias};
}
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#include "stream_manager.h"
#define SMGR_N_STREAMS 16
cudaStream_t CudaStreamManager::stream(size_t idx) {
return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::sync(int idx) {
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
}
void CudaStreamManager::setup(const int device) {
#ifdef FMOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
}
void CudaStreamManager::destroy() {
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
delete[] streams;
delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;
CudaStreamManager* getCudaStreamManager(const int device) {
auto it = smgrs.find(device);
if (it == smgrs.end()) {
smgr_mtx.lock();
it = smgrs.find(device);
if (it == smgrs.end()) {
auto smgr = new CudaStreamManager(device);
smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
smgr_mtx.unlock();
return smgr;
} else {
smgr_mtx.unlock();
}
}
return it->second;
}
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H
#include "helper_cuda.h"
#include "utils/helper_cuda.h"
#ifdef MOE_USE_NCCL
#ifdef FMOE_USE_NCCL
#include <nccl.h>
#define NCCL_SAFE_CALL(__fn__) { \
auto __res__ = __fn__; \
if (__res__ != ncclSuccess) { \
fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
exit(-1); \
} \
auto __res__ = __fn__; \
if (__res__ != ncclSuccess) { \
fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
exit(-1); \
} \
}
#endif
@@ -21,25 +21,25 @@ public:
int device;
cublasHandle_t* handles;
cudaStream_t* streams;
#ifdef FMOE_USE_NCCL
char ncclgood;
ncclComm_t ncclcomm;
#endif
public:
CudaStreamManager(int device_): device(device_) {
this->setup(device);
}
void setup(int);
void sync(int=0);
void destroy();
cudaStream_t stream(size_t=0);
cublasHandle_t handle(size_t=0);
~CudaStreamManager() {
this->destroy();
}
};
default : test_prune_gate test_limit
test_% : %.cu
nvcc $< ../stream_manager.cpp -lcublas -o $@
#include "../local_exchange.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
int n_worker = atoi(args[1]);
int n_expert = atoi(args[2]);
int batch_size = atoi(args[3]);
int topk = atoi(args[4]);
int tot_expert = n_worker * n_expert;
long* gate_idx = new long[batch_size * topk];
long* n_gate_idx = new long[batch_size * topk];
int* lec = new int[tot_expert];
memset(lec, 0, sizeof(int) * tot_expert);
for (int i = 0; i < batch_size * topk; ++i) {
if (rand() % 10) {
gate_idx[i] = rand() % tot_expert;
++lec[gate_idx[i]];
} else {
gate_idx[i] = -1;
}
}
for (int i = 1; i < tot_expert; ++i) {
lec[i] += lec[i - 1];
}
puts("gate idx");
for (int i = 0; i < batch_size * topk; ++i) {
printf("%d ", gate_idx[i]);
}
putchar(10);
int nlec = lec[tot_expert - 1];
int* g_lec;
cudaMalloc(&g_lec, sizeof(int) * tot_expert);
cudaMemcpy(g_lec, lec, sizeof(int) * tot_expert, cudaMemcpyHostToDevice);
long* g_gate_idx;
cudaMalloc(&g_gate_idx, sizeof(long) * batch_size * topk);
cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size * topk,
cudaMemcpyHostToDevice);
long* g_pos;
cudaMalloc(&g_pos, sizeof(long) * nlec);
// cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * nlec, cudaMemcpyHostToDevice);
auto smgr = getCudaStreamManager(0);
fmoe_cuda_assign_pos_impl(g_lec, g_gate_idx, g_pos, batch_size * topk,
topk, smgr);
long* pos = new long[nlec];
cudaMemcpy(pos, g_pos, sizeof(long) * nlec, cudaMemcpyDeviceToHost);
puts("pos");
for (int i = 0; i < nlec; ++i) {
printf("%d ", pos[i]);
}
putchar(10);
}
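// Example invocation (assumed, not part of the original sources): the
// arguments are n_worker n_expert batch_size topk, e.g. after building with
// the Makefile above,
//   ./test_prune_gate 2 4 16 1
// prints the randomly generated gate indices followed by the positions
// produced by fmoe_cuda_assign_pos_impl.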