Commit 56cb8c15 authored by Rick Ho

gshard prune gate cuda code and test

parent 3e8c263c
@@ -5,13 +5,25 @@
  * note that due to limit of cuda atomic operator, capacity should be int32
  */
 std::vector<torch::Tensor> _limit_by_capacity(
         torch::Tensor expert_count, torch::Tensor capacity,
-        long n_expert, long n_experts) {
+        long n_expert, long n_worker) {
     auto expert_count_ack = torch::empty_like(expert_count);
     auto smgr = getCudaStreamManager(expert_count.device().index());
     fmoe_cuda_limit_by_capacity_impl(
             expert_count.data_ptr<long>(),
             capacity.data_ptr<int>(),
             expert_count_ack.data_ptr<long>(),
-            n_expert, n_workers, smgr);
+            n_expert, n_worker, smgr);
+    return {expert_count_ack};
+}
+
+void _prune_gate_by_capacity(
+        torch::Tensor gate_idx, torch::Tensor expert_count,
+        long n_expert, long n_worker) {
+    auto smgr = getCudaStreamManager(expert_count.device().index());
+    auto batch_size = gate_idx.numel();
+    fmoe_cuda_prune_gate_by_capacity_impl(
+            gate_idx.data_ptr<long>(),
+            expert_count.data_ptr<int>(),
+            batch_size, n_expert, n_worker, smgr);
 }
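For orientation, a minimal sketch of driving these two bindings from Python. The alias fmoe_native matches the gate code further down, but the extension module name, tensor values, and shapes here are illustrative assumptions, not part of the commit. Note the dtype contract: the count tensors are int64, while anything decremented by atomicSub must be int32, per the comment above.

import torch
import fmoe_cuda as fmoe_native  # assumption: name of the built extension module

n_expert, n_worker = 2, 2
# proposals flattened as [n_worker, n_expert], int64
gec = torch.tensor([1, 3, 2, 0], dtype=torch.int64, device='cuda')
# one capacity per expert, int32 because of the atomicSub limitation
capacity = torch.tensor([2, 2], dtype=torch.int32, device='cuda')
new_gec, = fmoe_native.limit_by_capacity(gec, capacity, n_expert, n_worker)

# prune_gate_by_capacity works in place: entries of gate_idx whose expert
# budget is exhausted are overwritten with -1; lec holds one budget per
# global expert, i.e. n_worker * n_expert entries
gate_idx = torch.tensor([0, 1, 1, 3], dtype=torch.int64, device='cuda')
lec = torch.tensor([1, 1, 1, 2], dtype=torch.int32, device='cuda')
fmoe_native.prune_gate_by_capacity(gate_idx, lec, n_expert, n_worker)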
@@ -4,28 +4,51 @@
 __global__
 void limit_by_capacity_kernel(const long* ec, int* cap, long* eca,
         const long n_expert, const long n_worker) {
     int eid = blockIdx.y;
     int wid = blockIdx.x * blockDim.x + threadIdx.x;
     if (wid < n_worker) {
         int proposal = ec[wid * n_expert + eid];
         int cap_left = atomicSub(cap + eid, proposal);
         if (cap_left >= proposal) {
             eca[wid * n_expert + eid] = proposal;
         } else if (cap_left >= 0) {
             eca[wid * n_expert + eid] = cap_left;
         } else {
             eca[wid * n_expert + eid] = 0;
         }
     }
 }
 
 void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
         long* eca, const long n_expert, const long n_worker,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
     dim3 block_dim(1024);
     limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
             ec, cap, eca, n_expert, n_worker);
     smgr->sync(1);
+}
+
+__global__
+void prune_gate_by_capacity_kernel(long* gate_idx, int* ec,
+        const long batch_size, const long n_expert, const long n_worker) {
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < batch_size) {
+        int orig_cap = atomicSub(ec + gate_idx[i], 1);
+        if (orig_cap <= 0) {
+            gate_idx[i] = -1;
+        }
+    }
+}
+
+void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, int* ec,
+        const long batch_size, const long n_expert, const long n_worker,
+        CudaStreamManager* smgr) {
+    dim3 grid_dim(CEIL(batch_size, 1024));
+    dim3 block_dim(1024);
+    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+            gate_idx, ec, batch_size, n_expert, n_worker);
+    smgr->sync(1);
 }
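To pin down what the two kernels compute, here is a hedged sequential reference in plain Python (illustrative only, not part of the commit). atomicSub returns the value before the subtraction, so each worker's acknowledged count is its proposal clamped to whatever capacity was left when its atomic landed. On the GPU the atomics make the worker order arbitrary; the sketch fixes worker order 0..n_worker-1, so per-worker results can legitimately differ from a real run even though per-expert totals agree.

def limit_by_capacity_ref(ec, cap, n_expert, n_worker):
    # ec: proposals flattened as [n_worker, n_expert]; cap: per-expert capacity
    cap = list(cap)
    eca = [0] * (n_worker * n_expert)
    for eid in range(n_expert):
        for wid in range(n_worker):
            proposal = ec[wid * n_expert + eid]
            # clamp(proposal, 0, capacity left) == the kernel's three branches
            eca[wid * n_expert + eid] = max(0, min(proposal, cap[eid]))
            cap[eid] -= proposal
    return eca

def prune_gate_by_capacity_ref(gate_idx, ec):
    # ec: per-expert budget; assignments past the budget are replaced by -1
    ec = list(ec)
    out = []
    for g in gate_idx:
        ec[g] -= 1
        out.append(g if ec[g] >= 0 else -1)
    return out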
@@ -45,8 +45,11 @@ std::vector<torch::Tensor> _linear_backward(
 // balancing
 std::vector<torch::Tensor> _limit_by_capacity(
         torch::Tensor expert_count, torch::Tensor capacity,
-        long n_expert, long n_experts) {
+        long n_expert, long n_worker);
+void _prune_gate_by_capacity(
+        torch::Tensor gate_idx, torch::Tensor expert_count,
+        long n_expert, long n_worker);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef FMOE_USE_NCCL
@@ -63,5 +66,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
     m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
     m.def("limit_by_capacity", &_limit_by_capacity, "FastMoE limit experts by capacity (CUDA)");
+    m.def("prune_gate_by_capacity", &_prune_gate_by_capacity, "FastMoE prune gate by capacity (CUDA)");
 }
-default : test_balancing
+default : test_prune_gate test_limit
 
 test_% : %.cu
 	nvcc $< ../stream_manager.cpp -lcublas -o $@
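Given the pattern rule above, each test_<name> binary in the tests' Makefile is built from <name>.cu in the same directory (so the default target presumably compiles prune_gate.cu and limit.cu). Both standalone harnesses below take three positional arguments, e.g. ./test_limit 2 2 3 for 2 workers, 2 experts per worker, and a per-expert capacity of 3.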
#include "../balancing.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
    if (argc < 4) {
        fprintf(stderr, "usage: %s <n_worker> <n_expert> <capacity>\n", args[0]);
        return 1;
    }
    int n_worker = atoi(args[1]);
    int n_expert = atoi(args[2]);
    int cap_v = atoi(args[3]);
    int tot_expert = n_worker * n_expert;

    // proposals: worker w asks expert e for w * n_expert + e slots
    long* lec = new long[tot_expert];
    for (int i = 0; i < tot_expert; ++i) {
        lec[i] = i;
    }
    long* g_lec;
    cudaMalloc(&g_lec, sizeof(long) * tot_expert);
    cudaMemcpy(g_lec, lec, sizeof(long) * tot_expert, cudaMemcpyHostToDevice);

    // every expert gets the same capacity
    int* cap = new int[n_expert];
    for (int i = 0; i < n_expert; ++i) {
        cap[i] = cap_v;
    }
    int* g_cap;
    cudaMalloc(&g_cap, sizeof(int) * n_expert);
    cudaMemcpy(g_cap, cap, sizeof(int) * n_expert, cudaMemcpyHostToDevice);

    long* eca = new long[tot_expert];
    long* g_eca;
    cudaMalloc(&g_eca, sizeof(long) * tot_expert);

    auto smgr = getCudaStreamManager(0);
    fmoe_cuda_limit_by_capacity_impl(g_lec, g_cap, g_eca, n_expert, n_worker, smgr);

    // print expert 0's remaining capacity, then each proposal next to its
    // acknowledged count
    cudaMemcpy(cap, g_cap, sizeof(int) * n_expert, cudaMemcpyDeviceToHost);
    cudaMemcpy(eca, g_eca, sizeof(long) * tot_expert, cudaMemcpyDeviceToHost);
    printf("%d\n", cap[0]);
    for (int i = 0; i < tot_expert; ++i) {
        printf("%ld %ld\n", lec[i], eca[i]);
    }
}
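As a worked example of this harness (hedged, since the atomics make worker order arbitrary): with n_worker=2, n_expert=2, and capacity 3, the proposals are lec = [0, 1, 2, 3] laid out as [worker][expert]. Expert 0 receives proposals 0 and 2, both fit, so its acknowledged counts are (0, 2) and the printed cap[0] is 3 - 0 - 2 = 1. Expert 1 receives 1 and 3 against capacity 3, so its acknowledgements are either (1, 2) or (0, 3) depending on which atomicSub lands first; both orderings sum to the capacity. The second harness below exercises the prune kernel.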
#include "../balancing.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
int n_worker = atoi(args[1]);
int n_expert = atoi(args[2]);
int cap_v = atoi(args[3]);
int tot_expert = n_worker * n_expert;
long* lec = new long[tot_expert];
for (int i = 0; i < tot_expert; ++i) {
lec[i] = i;
}
long* g_lec;
cudaMalloc(&g_lec, sizeof(long) * tot_expert);
cudaMemcpy(g_lec, lec, sizeof(long) * tot_expert, cudaMemcpyHostToDevice);
int* cap = new int[n_expert];
for (int i = 0; i < n_expert; ++i) {
cap[i] = cap_v;
}
int* g_cap;
cudaMalloc(&g_cap, sizeof(int) * n_expert);
cudaMemcpy(g_cap, cap, sizeof(int) * n_expert, cudaMemcpyHostToDevice);
long* eca = new long[tot_expert];
long* g_eca;
cudaMalloc(&g_eca, sizeof(long) * tot_expert);
auto smgr = getCudaStreamManager(0);
fmoe_cuda_limit_by_capacity_impl(g_lec, g_cap, g_eca, n_expert, n_worker, smgr);
cudaMemcpy(cap, g_cap, sizeof(int) * n_expert, cudaMemcpyDeviceToHost);
cudaMemcpy(eca, g_eca, sizeof(long) * tot_expert, cudaMemcpyDeviceToHost);
printf("%d\n", cap[0]);
for (int i = 0; i < tot_expert; ++i) {
printf("%ld %ld\n", lec[i], eca[i]);
}
}
#include "../balancing.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
int main(int argc, char* args[]) {
    if (argc < 4) {
        fprintf(stderr, "usage: %s <n_worker> <n_expert> <batch_size>\n", args[0]);
        return 1;
    }
    int n_worker = atoi(args[1]);
    int n_expert = atoi(args[2]);
    int batch_size = atoi(args[3]);
    int tot_expert = n_worker * n_expert;

    // assign each sample a random expert and count the per-expert load
    long* gate_idx = new long[batch_size];
    long* n_gate_idx = new long[batch_size];
    int* lec = new int[tot_expert];
    memset(lec, 0, sizeof(int) * tot_expert);
    for (int i = 0; i < batch_size; ++i) {
        gate_idx[i] = rand() % tot_expert;
        ++lec[gate_idx[i]];
    }
    // halve each budget so roughly half of every expert's samples get pruned
    for (int i = 0; i < tot_expert; ++i) {
        lec[i] >>= 1;
    }

    int* g_lec;
    cudaMalloc(&g_lec, sizeof(int) * tot_expert);
    cudaMemcpy(g_lec, lec, sizeof(int) * tot_expert, cudaMemcpyHostToDevice);
    long* g_gate_idx;
    cudaMalloc(&g_gate_idx, sizeof(long) * batch_size);
    cudaMemcpy(g_gate_idx, gate_idx, sizeof(long) * batch_size, cudaMemcpyHostToDevice);

    auto smgr = getCudaStreamManager(0);
    fmoe_cuda_prune_gate_by_capacity_impl(g_gate_idx, g_lec,
            batch_size, n_expert, n_worker, smgr);

    // print the original index, the pruned index, and the halved budget
    cudaMemcpy(n_gate_idx, g_gate_idx, sizeof(long) * batch_size, cudaMemcpyDeviceToHost);
    for (int i = 0; i < batch_size; ++i) {
        printf("%ld %ld (%d)\n", gate_idx[i], n_gate_idx[i], lec[gate_idx[i]]);
    }
}
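Each printed row shows the original assignment, the index after pruning, and the halved per-expert budget (the host copy of lec, which the kernel's device-side decrements never touch); since each budget was cut to half of the expert's true load, roughly half of the rows for each expert should come back as -1.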
@@ -34,13 +34,15 @@ class GShardGate(NaiveGate):
         capacity *= math.ceil(cap_rate * x.shape[0])
         pos, lec, gec = count_by_gate(gate, self.num_expert, self.world_size)
-        new_gec = fmoe_native.limit_by_capacity(gec, capacity,
+        new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
                 self.num_expert, self.world_size)
         if self.world_size > 1:
             new_lec = fmoe_native.expert_exchange(new_gec,
                     self.num_expert, self.world_size)
         else:
             new_lec = new_gec
+        # TODO: re-assign gate
+        fmoe_native.prune_gate_by_capacity(topk_idx,
+                new_lec.to(torch.int32), self.num_expert, self.world_size)
         return topk_idx, topk_val
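The TODO above defers re-assignment of pruned samples, so after this hunk topk_idx can contain -1 markers written in place by prune_gate_by_capacity. A hedged sketch of how a consumer might drop those entries (illustrative; the variable handling is an assumption, not the commit's code):

mask = topk_idx != -1          # positions that survived capacity pruning
kept_idx = topk_idx[mask]
kept_val = topk_val[mask]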
@@ -27,6 +27,7 @@ if __name__ == '__main__':
         sources=[
             'cuda/stream_manager.cpp',
             'cuda/local_exchange.cu',
+            'cuda/balancing.cu',
             'cuda/global_exchange.cpp',
             'cuda/parallel_linear.cpp',
             'cuda/fmoe_cuda.cpp',