Commit 980cf4b6 authored by Rick Ho

use customized pos assignment kernel to support -1

parent a468db2b
@@ -22,15 +22,16 @@ void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // FMOE_USE_NCCL
// local_exchange
std::vector<torch::Tensor> _expert_count(
        torch::Tensor gate,
        size_t num_expert);
std::vector<torch::Tensor> _local_scatter(
        torch::Tensor input,
        torch::Tensor pos);
std::vector<torch::Tensor> _local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos);
void _assign_pos(
        torch::Tensor cum_count,
        torch::Tensor gate,
        torch::Tensor pos);

// parallel_linear
std::vector<torch::Tensor> _linear_forward(
@@ -59,9 +60,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
#endif
m.def("expert_count", &_expert_count, "FastMoE expert count (CUDA)");
m.def("local_scatter", &_local_scatter, "FastMoE local scatter (CUDA)");
m.def("local_gather", &_local_gather, "FastMoE local gather (CUDA)");
m.def("assign_pos_", &_assign_pos, "FastMoE assign pos by gate(CUDA)");
m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
......
@@ -2,28 +2,6 @@
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
std::vector<torch::Tensor> _expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    const auto batch_size = gate.size(0);
    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
    auto expert_count = torch::empty(num_expert, ec_options);
    auto pos_options = torch::TensorOptions()
        .device(gate.device())
        .dtype(torch::kInt32);
    auto pos = torch::empty(batch_size, pos_options);
    fmoe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);
    return {expert_count, pos};
}

std::vector<torch::Tensor> _local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
@@ -73,3 +51,20 @@ std::vector<torch::Tensor> _local_gather(
    }));
    return {output,};
}
void _assign_pos(
        torch::Tensor cum_count,
        torch::Tensor gate,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(cum_count.device().index());
    auto gate_shp = gate.sizes();
    size_t batch_size = gate_shp[0], topk = 1;
    if (gate_shp.size() == 2) {
        topk = gate_shp[1];
    }
    fmoe_cuda_assign_pos_impl(
            cum_count.data_ptr<int>(),
            gate.data_ptr<long>(),
            pos.data_ptr<long>(),
            batch_size, topk, smgr);
}
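A note on the wrapper above, inferred from the calls it makes: cum_count is read as int32 while gate and pos are handled as int64, and gate may be either one-dimensional (batch_size,) or two-dimensional (batch_size, topk), with topk taken from the second dimension. A minimal sketch of a hypothetical top-2 call, with invented tensor values, assuming a CUDA build of the extension:

import torch
import fmoe_cuda  # the extension bound in this commit

# hypothetical top-2 gate for 3 tokens and 2 experts; -1 marks a dropped slot
gate = torch.tensor([[0, 1], [1, -1], [0, 1]], dtype=torch.long, device="cuda")
counts = torch.tensor([2, 3], dtype=torch.long, device="cuda")  # tokens per expert
cum_count = torch.cumsum(counts, dim=0).int()                   # int32, as the kernel expects
pos = torch.empty(int(cum_count[-1]), dtype=torch.long, device="cuda")
fmoe_cuda.assign_pos_(cum_count, gate, pos)                     # fills pos in place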
#include "stream_manager.h"
#include "utils/helper_cuda.h"
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
        const long* offset, const scalar_t** ptrs) {
    size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < n) {
        ptrs[idx] = base + stride * offset[idx];
    }
}
#include "utils/fmoe_utils.h"
template <typename scalar_t>
__global__
@@ -22,42 +13,6 @@ void batch_scatter_kernel(size_t wid, const long* pos,
    }
}
void fmoe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    int *gate = new int[batch_size];
    int *expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);
    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));
    for (int i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
    }
    int *pos = new int[batch_size];
    for (int i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
    }
    for (int i = num_expert - 1; i > 0; --i) {
        expert_ptr[i] = expert_ptr[i - 1];
    }
    expert_ptr[0] = 0;
    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
    delete [] gate;
    delete [] expert_ptr;
}
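For orientation, the host-side routine removed above computed a per-expert histogram, an exclusive prefix sum over it, and then a stable slot for every token; note that there pos[i] held the destination slot of token i, whereas the new kernel writes the token index into its slot. A plain-Python paraphrase of the deleted logic, illustrative only and not repository code:

def expert_count_reference(gate, num_expert):
    """CPU paraphrase of the removed fmoe_cuda_expert_count_impl (no -1 handling)."""
    expert_count = [0] * num_expert
    for g in gate:                       # per-expert histogram
        expert_count[g] += 1
    expert_ptr = [0] * num_expert        # exclusive prefix sum of the counts
    for e in range(1, num_expert):
        expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1]
    pos = [0] * len(gate)
    for i, g in enumerate(gate):         # hand out consecutive slots per expert
        pos[i] = expert_ptr[g]
        expert_ptr[g] += 1
    return expert_count, pos

# expert_count_reference([1, 0, 1, 2], 3) -> ([1, 2, 1], [1, 0, 2, 3])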
template <typename scalar_t>
void fmoe_cuda_local_scatter_impl(
        const scalar_t* input,
@@ -96,3 +51,27 @@ void fmoe_cuda_local_gather_impl(
            output);
    smgr->sync(1);
}
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
        size_t numel, size_t topk) {
    size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < numel) {
        long gate_idx = gate[idx];
        if (gate_idx > -1) {
            // atomicSub returns the value before the decrement, so the slot
            // owned by this thread is old - 1; entries with gate == -1 never
            // claim a slot at all
            int p = atomicSub(cum_count + gate_idx, 1) - 1;
            pos[p] = (long)idx;
        }
    }
}
void fmoe_cuda_assign_pos_impl(
        int* cum_count, const long* gate, long* pos,
        const size_t batch_size, const size_t topk,
        CudaStreamManager* smgr) {
    size_t numel = batch_size * topk;
    assign_pos_kernel
        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>(cum_count, gate, pos,
                numel, topk);
    smgr->sync(1);
}
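To make the assignment scheme concrete: cum_count arrives as the inclusive cumulative sum of per-expert counts, each thread that handles a valid gate entry atomically decrements its expert's counter, the decremented value is the slot it owns, and entries equal to -1 are skipped entirely, which is the point of this commit. A serial Python equivalent of the kernel (a sketch; on the GPU the order of tokens within one expert is not deterministic):

def assign_pos_reference(cum_count, gate):
    """Serial stand-in for assign_pos_kernel."""
    pos = [None] * cum_count[-1]
    cum = list(cum_count)                # mutable copy of the inclusive cumsum
    for idx, g in enumerate(gate):
        if g > -1:                       # -1 means: ignore this token/slot
            cum[g] -= 1                  # emulates atomicSub; new value = old - 1
            pos[cum[g]] = idx
    return pos

# counts = [1, 0, 2] for gate = [2, -1, 0, 2], so cum_count = [1, 1, 3]
# assign_pos_reference([1, 1, 3], [2, -1, 0, 2]) -> [2, 3, 0]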
@@ -16,15 +16,17 @@ def _ensure_nccl(t, comm=None):
    fmoe_cuda.ensure_nccl(comm, t)
def count_by_gate(gate, num_expert, world_size):
    # TODO: support -1 in gate, which means ignore this input
def count_by_gate(gate, num_expert, world_size, require_pos=True):
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        flatten_gate = gate.view(-1)
        eff_gate = flatten_gate[flatten_gate != -1]
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
        ones = torch.ones(eff_gate.numel(),
                device=gate.device, dtype=torch.long)
        local_expert_count.index_add_(0, eff_gate, ones)

        if world_size > 1:
            _ensure_nccl(gate)
@@ -33,6 +35,13 @@ def count_by_gate(gate, num_expert, world_size):
            )
        else:
            global_expert_count = local_expert_count

        if not require_pos:
            pos = None
        else:
            lec_cum = torch.cumsum(local_expert_count, dim=0).int()
            pos_size = lec_cum[-1].item()
            pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
            fmoe_cuda.assign_pos_(lec_cum, gate, pos)
    return pos, local_expert_count, global_expert_count
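Taken together, count_by_gate now drops -1 entries when accumulating local_expert_count and, when positions are requested, sizes pos by the number of surviving entries rather than by batch_size. A usage sketch under assumed values (single process, two experts, CUDA tensors):

import torch

gate = torch.tensor([1, 0, -1, 1], device="cuda")   # one token is marked as dropped
pos, lec, gec = count_by_gate(gate, num_expert=2, world_size=1)
# lec == gec == tensor([1, 2]): the -1 entry is not counted
# pos has 3 elements: the indices of the kept tokens, grouped by expert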
......
@@ -36,7 +36,6 @@ class NaiveGate(BaseGate):
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

        if return_all_scores:
            return gate_top_k_idx, gate_top_k_val, gate
......
@@ -10,7 +10,8 @@ def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
    capacity = torch.ones(num_expert, dtype=torch.int32,
            device=topk_idx.device) * capacity
    pos, lec, gec = count_by_gate(topk_idx.reshape(-1), num_expert, world_size)
    pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
            require_pos=False)
    new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
            num_expert, world_size)
    if world_size > 1:
......
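In limit_by_capacity only the per-expert counts matter for clipping, so the updated call passes require_pos=False and skips the position kernel; count_by_gate then simply returns None for pos, as in this sketch:

# counts only: no pos buffer is allocated and no assignment kernel is launched
pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size, require_pos=False)
assert pos is None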