Unverified Commit 295a615a authored by Rick Ho, committed by GitHub

Improve efficiency of metadata exchange (#48)

* use single variable instead of vector in c functions

* expert count kernel

* remove all lists

* fix old tests
parent b861e928
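
The theme of the commit is a calling-convention cleanup: every C-side wrapper that returned its single result boxed in a std::vector<torch::Tensor> (which pybind11 exposes as a one-element list that Python call sites had to unpack) now returns the torch::Tensor directly, and a new expert_count CUDA kernel replaces the index_add_-based counting on the Python side. A minimal before/after sketch of the wrapper change (hypothetical function name, not part of the commit):

#include <torch/extension.h>
#include <vector>

// Before: the single result was boxed in a std::vector, so pybind11
// handed Python a one-element list and callers wrote
// (out,) = fmoe_cuda.some_op(...) to unpack it.
std::vector<torch::Tensor> some_op_old(torch::Tensor buf) {
    return {buf};
}

// After: the tensor is returned as-is; no container allocation on the
// C++ side and no tuple unpacking on the Python side.
torch::Tensor some_op_new(torch::Tensor buf) {
    return buf;
}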
@@ -5,15 +5,15 @@
 // global_exchange
 #ifdef FMOE_USE_NCCL
 #include <c10d/ProcessGroupNCCL.hpp>
-std::vector<torch::Tensor> _expert_exchange(
+torch::Tensor _expert_exchange(
         torch::Tensor local_expert_count,
         long n_expert, long n_workers);
-std::vector<torch::Tensor> _global_scatter(
+torch::Tensor _global_scatter(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
-std::vector<torch::Tensor> _global_gather(
+torch::Tensor _global_gather(
         torch::Tensor output_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
@@ -26,9 +26,12 @@ void _assign_pos(
         torch::Tensor cum_count,
         torch::Tensor gate,
         torch::Tensor pos);
+void _expert_count(
+        torch::Tensor gate_idx,
+        torch::Tensor expert_count);
 
 // parallel_linear
-std::vector<torch::Tensor> _linear_forward(
+torch::Tensor _linear_forward(
         torch::Tensor input_buf,
         torch::Tensor expert_count,
         torch::Tensor weight,
@@ -58,7 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
 #endif
-    m.def("assign_pos_", &_assign_pos, "FastMoE assign pos by gate(CUDA)");
+    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
+    m.def("assign_pos", &_assign_pos, "FastMoE assign pos by gate (CUDA)");
     m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
     m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
...
@@ -5,7 +5,7 @@
 #ifdef FMOE_USE_NCCL
 #include <nccl.h>
 
-std::vector<torch::Tensor> _expert_exchange(
+torch::Tensor _expert_exchange(
         torch::Tensor local_expert_count,
         long n_expert, long n_workers) {
     auto global_expert_count = torch::empty_like(local_expert_count);
@@ -16,10 +16,10 @@ std::vector<torch::Tensor> _expert_exchange(
             global_expert_count.data_ptr<long>(),
             n_expert, n_workers,
             smgr);
-    return {global_expert_count};
+    return global_expert_count;
 }
 
-std::vector<torch::Tensor> _global_scatter(
+torch::Tensor _global_scatter(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
@@ -42,10 +42,10 @@ std::vector<torch::Tensor> _global_scatter(
             smgr
         );
     }));
-    return {global_input_buf,};
+    return global_input_buf;
 }
 
-std::vector<torch::Tensor> _global_gather(
+torch::Tensor _global_gather(
         torch::Tensor output_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
@@ -68,7 +68,7 @@ std::vector<torch::Tensor> _global_gather(
             smgr
         );
     }));
-    return {local_output_buf,};
+    return local_output_buf;
 }
 
 #include <c10d/ProcessGroupNCCL.hpp>
@@ -86,7 +86,7 @@ public:
                 "fastmoe_nccl_comm",
                 rank);
         ncclComm_t comm;
-        ncclCommInitRank(&comm, getSize(), ncclID, rank);
+        NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank));
        return comm;
    }
 };
...
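
The last hunk above also hardens communicator creation: ncclCommInitRank is now wrapped in NCCL_SAFE_CALL. The macro's definition is not part of this diff; the sketch below shows the usual shape of such a guard (an assumption for context, not FastMoE's actual macro):

#include <nccl.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical sketch of an NCCL_SAFE_CALL-style guard: check the
// ncclResult_t and abort with a readable message on failure.
#define NCCL_SAFE_CALL_SKETCH(call)                                   \
    do {                                                              \
        ncclResult_t r = (call);                                      \
        if (r != ncclSuccess) {                                       \
            fprintf(stderr, "NCCL error %s at %s:%d\n",               \
                    ncclGetErrorString(r), __FILE__, __LINE__);       \
            exit(EXIT_FAILURE);                                       \
        }                                                             \
    } while (0)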
@@ -18,3 +18,15 @@ void _assign_pos(
             pos.data_ptr<long>(),
             batch_size, topk, smgr);
 }
+
+void _expert_count(
+        torch::Tensor gate_idx,
+        torch::Tensor expert_count) {
+    auto smgr = getCudaStreamManager(gate_idx.device().index());
+    auto batch_size = gate_idx.numel();
+    auto n_expert = expert_count.numel();
+    fmoe_cuda_expert_count_impl(
+            gate_idx.data_ptr<long>(),
+            expert_count.data_ptr<int>(),
+            batch_size, n_expert, smgr);
+}
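
_expert_count fills a caller-allocated count buffer in place: gate_idx carries one long expert index per token, and expert_count must be a zero-initialized int32 tensor with one slot per expert, matching the data_ptr<long>/data_ptr<int> casts above. A minimal call sketch (illustrative sizes; assumes a CUDA device is available):

// Count how many of 4096 tokens were routed to each of 64 experts.
auto opts_long = torch::dtype(torch::kLong).device(torch::kCUDA);
auto opts_int  = torch::dtype(torch::kInt32).device(torch::kCUDA);
auto gate_idx = torch::randint(0, /*num_expert=*/64, {4096}, opts_long);
auto expert_count = torch::zeros({64}, opts_int);
_expert_count(gate_idx, expert_count);
// expert_count[i] now holds the number of tokens routed to expert i.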
@@ -25,3 +25,47 @@ void fmoe_cuda_assign_pos_impl(
         (cum_count, gate, pos, numel, topk);
     smgr->sync(1);
 }
+
+#define PERTHREAD_EXPERTS 256
+#define WARP_SIZE 32
+
+__global__
+void expert_count_kernel(const long* gate_idx, int* expert_count,
+        const size_t batch_size, const size_t n_expert) {
+    int res_tmp[PERTHREAD_EXPERTS] = {0};
+    long expert_min = blockIdx.x * PERTHREAD_EXPERTS;
+    long expert_max = expert_min + PERTHREAD_EXPERTS;
+    if (expert_max > n_expert) {
+        expert_max = n_expert;
+    }
+    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
+        long idx = gate_idx[i];
+        if (idx == -1) {
+            continue;
+        }
+        if (idx < expert_min || idx >= expert_max) {
+            continue;
+        }
+        res_tmp[idx - expert_min] += 1;
+    }
+    for (int i = expert_min; i < expert_max; ++i) {
+        int x = res_tmp[i - expert_min];
+#pragma unroll
+        for (int j = 1; j < WARP_SIZE; j <<= 1) {
+            x = x + __shfl_down_sync(-1u, x, j);
+        }
+        if (threadIdx.x % WARP_SIZE == 0) {
+            atomicAdd(expert_count + i, x);
+        }
+    }
+}
+
+void fmoe_cuda_expert_count_impl(
+        const long* gate_idx, int* expert_count,
+        const size_t batch_size, const size_t n_expert,
+        CudaStreamManager* smgr) {
+    expert_count_kernel
+        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
+        (gate_idx, expert_count, batch_size, n_expert);
+    smgr->sync(1);
+}
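
The kernel avoids shared memory entirely: each block owns a window of PERTHREAD_EXPERTS consecutive experts, each thread strides over the batch (step blockDim.x) accumulating private counts in res_tmp, skipping the -1 entries that mark dropped tokens, and the private counts are then merged per warp with shuffle instructions so that only one atomicAdd per warp reaches each expert_count slot. In isolation, the reduction step looks like the sketch below (assuming a full 32-lane warp, which the 256-thread launch guarantees):

// Standalone sketch of the warp reduction used above: after the
// log2(32) = 5 shuffle steps, lane 0 holds the sum of x across the
// warp, so a single atomicAdd per warp suffices.
__device__ int warp_sum(int x) {
    for (int offset = 1; offset < 32; offset <<= 1) {
        x += __shfl_down_sync(0xffffffffu, x, offset);
    }
    return x;  // meaningful in lane 0 only
}

With 256 threads per block, that is 8 warps, hence at most 8 atomic additions per expert per block instead of one per thread.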
@@ -2,7 +2,7 @@
 #include "utils/fmoe_utils.h"
 #include <torch/extension.h>
 
-std::vector<torch::Tensor> _linear_forward(
+torch::Tensor _linear_forward(
         torch::Tensor input_buf,
         torch::Tensor expert_count,
         torch::Tensor weight,
@@ -45,7 +45,7 @@ std::vector<torch::Tensor> _linear_forward(
         );
     }));
-    return {output, };
+    return output;
 }
...
-default : test_prune_gate test_limit
+default : test_prune_gate test_limit test_assign test_counting
 
 test_% : %.cu
 	nvcc $< ../stream_manager.cpp -lcublas -o $@
#include "../local_exchange.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
int batch_size = atoi(args[1]);
int n_expert = atoi(args[2]);
long* gate_idx = new long[batch_size];
long* n_gate_idx = new long[batch_size];
int* ref_lec = new int[n_expert];
memset(ref_lec, 0, sizeof(int) * n_expert);
for (int i = 0; i < batch_size; ++i) {
gate_idx[i] = rand() % (n_expert + 1) - 1;
if (gate_idx[i] != -1) {
ref_lec[gate_idx[i]] += 1;
}
}
puts("ref lec");
for (int i = 0; i < n_expert; ++i) {
printf("%d ", ref_lec[i]);
}
putchar(10);
int* g_lec;
cudaMalloc(&g_lec, sizeof(int) * n_expert);
cudaMemset(g_lec, 0, sizeof(int) * n_expert);
long* g_gate_idx;
cudaMalloc(&g_gate_idx, sizeof(long) * batch_size);
cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size,
cudaMemcpyHostToDevice);
auto smgr = getCudaStreamManager(0);
fmoe_cuda_expert_count_impl(g_gate_idx, g_lec, batch_size, n_expert, smgr);
int* lec = new int[n_expert];
cudaMemcpy(lec, g_lec, sizeof(int) * n_expert, cudaMemcpyDeviceToHost);
puts("lec");
for (int i = 0; i < n_expert; ++i) {
printf("%d ", lec[i]);
}
putchar(10);
}
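
The new harness relies on visual comparison of the two printed vectors (`ref lec` from the CPU loop, `lec` from the kernel); built by the `test_%` rule above, it is invoked as, e.g., `test_counting 4096 256`, with -1 gate entries modeling dropped tokens that must not be counted. A stricter ending is easy to sketch; the loop below (an illustration, not part of the commit) turns the eyeball check into a hard failure:

    // Sketch: compare kernel output against the CPU reference and fail
    // loudly on the first mismatch instead of printing both vectors.
    for (int i = 0; i < n_expert; ++i) {
        if (lec[i] != ref_lec[i]) {
            fprintf(stderr, "mismatch at expert %d: got %d, want %d\n",
                    i, lec[i], ref_lec[i]);
            return 1;
        }
    }
    return 0;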
@@ -14,8 +14,9 @@ int main(int argc, char* args[]) {
     long* gate_idx = new long[batch_size];
     long* n_gate_idx = new long[batch_size];
-    int* lec = new int[tot_expert];
-    memset(lec, 0, sizeof(int) * tot_expert);
+    long* lec = new long[tot_expert];
+    memset(lec, 0, sizeof(long) * tot_expert);
     for (int i = 0; i < batch_size; ++i) {
         gate_idx[i] = rand() % tot_expert;
         ++lec[gate_idx[i]];
@@ -23,15 +24,19 @@ int main(int argc, char* args[]) {
     for (int i = 0; i < tot_expert; ++i) {
         lec[i] >>= 1;
     }
-    int* g_lec;
-    cudaMalloc(&g_lec, sizeof(int) * tot_expert);
-    cudaMemcpy(g_lec, lec, sizeof(int) * tot_expert, cudaMemcpyHostToDevice);
+    long* g_lec;
+    cudaMalloc(&g_lec, sizeof(long) * tot_expert);
+    cudaMemcpy(g_lec, lec, sizeof(long) * tot_expert, cudaMemcpyHostToDevice);
+    int* g_new_lec;
+    cudaMalloc(&g_new_lec, sizeof(int) * tot_expert);
     long* g_gate_idx;
     cudaMalloc(&g_gate_idx, sizeof(long) * batch_size);
     cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size, cudaMemcpyHostToDevice);
     auto smgr = getCudaStreamManager(0);
-    fmoe_cuda_prune_gate_by_capacity_impl(g_gate_idx, g_lec,
+    fmoe_cuda_prune_gate_by_capacity_impl(g_gate_idx, g_lec, g_new_lec,
             batch_size, n_expert, n_worker, smgr);
     cudaMemcpy(n_gate_idx, g_gate_idx, sizeof(long) * batch_size, cudaMemcpyDeviceToHost);
...
@@ -18,19 +18,15 @@ def _ensure_nccl(t, comm=None):
 
 def count_by_gate(gate, num_expert, world_size, require_pos=True):
     with torch.no_grad():
-        flatten_gate = gate.view(-1)
-        eff_gate = flatten_gate[flatten_gate != -1]
         local_expert_count = torch.zeros(
-            num_expert * world_size, device=gate.device, dtype=torch.long
+            num_expert * world_size, device=gate.device, dtype=torch.int32
         )
-        ones = torch.ones(eff_gate.numel(),
-                device=gate.device, dtype=torch.long)
-        local_expert_count.index_add_(0, eff_gate, ones)
+        fmoe_cuda.expert_count(gate, local_expert_count)
+        local_expert_count = local_expert_count.long()
         if world_size > 1:
             _ensure_nccl(gate)
-            (global_expert_count,) = fmoe_cuda.expert_exchange(
+            global_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size
             )
         else:
@@ -41,7 +37,7 @@ def count_by_gate(gate, num_expert, world_size, require_pos=True):
         lec_cum = torch.cumsum(local_expert_count, dim=0).int()
         pos_size = lec_cum[-1].item()
         pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
-        fmoe_cuda.assign_pos_(lec_cum, gate, pos)
+        fmoe_cuda.assign_pos(lec_cum, gate, pos)
     return pos, local_expert_count, global_expert_count
@@ -108,7 +104,7 @@ class MOEScatter(Function):
     ):
         local_input_buf = _local_scatter(inp, pos)
         if world_size > 1:
-            (global_input_buf,) = fmoe_cuda.global_scatter(
+            global_input_buf = fmoe_cuda.global_scatter(
                 local_input_buf,
                 local_expert_count,
                 global_expert_count,
@@ -128,7 +124,7 @@ class MOEScatter(Function):
         (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
-            (local_grad_in,) = fmoe_cuda.global_gather(
+            local_grad_in = fmoe_cuda.global_gather(
                 global_grad_in,
                 local_expert_count,
                 global_expert_count,
@@ -148,7 +144,7 @@ class MOELinear(Function):
     @staticmethod
     def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
-        (global_output_buf,) = fmoe_cuda.linear_forward(
+        global_output_buf = fmoe_cuda.linear_forward(
             global_input_buf, fwd_expert_count, weight, bias
         )
         variables = (global_input_buf, fwd_expert_count, weight, bias)
@@ -185,7 +181,7 @@ class MOEGather(Function):
         world_size,
     ):
         if world_size > 1:
-            (local_output_buf,) = fmoe_cuda.global_gather(
+            local_output_buf = fmoe_cuda.global_gather(
                 global_output_buf,
                 local_expert_count,
                 global_expert_count,
@@ -208,7 +204,7 @@ class MOEGather(Function):
         fwd_batch_size, world_size = ctx.moe_args
         grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
         if world_size > 1:
-            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
+            global_grad_out_buf = fmoe_cuda.global_scatter(
                 grad_out_buf,
                 local_expert_count,
                 global_expert_count,
...