Commit 52e57316 authored by Rick Ho's avatar Rick Ho
Browse files

remove local scatter and gather in cuda

parent 414a2f86
......@@ -22,12 +22,6 @@ void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // FMOE_USE_NCCL
// local_exchange
std::vector<torch::Tensor> _local_scatter(
torch::Tensor input,
torch::Tensor pos);
std::vector<torch::Tensor> _local_gather(
torch::Tensor output_buf,
torch::Tensor pos);
void _assign_pos(
torch::Tensor cum_count,
torch::Tensor gate,
......@@ -60,8 +54,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
#endif
m.def("local_scatter", &_local_scatter, "FastMoE local scatter (CUDA)");
m.def("local_gather", &_local_gather, "FastMoE local gather (CUDA)");
m.def("assign_pos_", &_assign_pos, "FastMoE assign pos by gate(CUDA)");
m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
......
......@@ -2,56 +2,6 @@
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
// Reorder rows of `input` into a contiguous buffer: row j of the result is
// input[pos[j]].  `pos` is a 1-D long tensor of source row indices.
// Returns a one-element vector holding the scattered buffer.
std::vector<torch::Tensor> _local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
    auto stream_mgr = getCudaStreamManager(input.device().index());
    const auto n_rows = pos.size(0);
    const auto row_width = input.size(1);
    // Output buffer lives on the same device with the same dtype as `input`.
    auto buf_opts = torch::TensorOptions()
        .dtype(input.dtype())
        .device(input.device());
    auto scattered = torch::empty({n_rows, row_width}, buf_opts);
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fmoe_local_scatter",
            ([&] {
        fmoe_cuda_local_scatter_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            scattered.data_ptr<scalar_t>(),
            n_rows,
            row_width,
            stream_mgr);
    }));
    return {scattered,};
}
// Inverse of _local_scatter: writes row j of `output_buf` back to row pos[j]
// of a freshly allocated output tensor.  `pos` is a 1-D long tensor of
// destination row indices.  Returns a one-element vector holding the result.
std::vector<torch::Tensor> _local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos) {
    auto stream_mgr = getCudaStreamManager(output_buf.device().index());
    const auto n_rows = pos.size(0);
    const auto row_width = output_buf.size(1);
    // Result tensor mirrors the buffer's dtype and device.
    auto out_opts = torch::TensorOptions()
        .dtype(output_buf.dtype())
        .device(output_buf.device());
    auto gathered = torch::empty({n_rows, row_width}, out_opts);
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "fmoe_local_gather",
            ([&] {
        fmoe_cuda_local_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            gathered.data_ptr<scalar_t>(),
            n_rows,
            row_width,
            stream_mgr);
    }));
    return {gathered,};
}
void _assign_pos(
torch::Tensor cum_count,
torch::Tensor gate,
......
......@@ -2,56 +2,6 @@
#include "utils/helper_cuda.h"
#include "utils/fmoe_utils.h"
/*
 * Row-permuting copy: block b copies row pos[b] of inbuf into row b of oubuf.
 * Launch contract: gridDim.x == number of rows; threads of a block stride
 * cooperatively over the `wid` elements of one row.  inbuf and oubuf must not
 * alias (callers pass distinct tensors), hence __restrict__.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* __restrict__ pos,
        const scalar_t* __restrict__ inbuf, scalar_t* __restrict__ oubuf) {
    // Advance to the source row selected by pos[b] and the dense target row b.
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    // size_t index: the original used int, which is a signed/unsigned
    // comparison against wid and would overflow for rows > INT_MAX elements.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}
/*
 * Launch helper for batch_scatter_kernel: copies row d_pos[b] of `input` into
 * row b of `input_buf` for b in [0, batch_size).
 * One block per row, 256 threads per block, on the manager's stream 0,
 * followed by smgr->sync(1).
 * Size parameters are size_t for consistency with the gather impl and the
 * kernel's `size_t wid` parameter (previously `long`; callers pass
 * nonnegative tensor sizes, so the change is backward-compatible).
 */
template <typename scalar_t>
void fmoe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const size_t batch_size,
        const size_t in_feat,
        CudaStreamManager* smgr) {
    batch_scatter_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                input_buf);
    smgr->sync(1);
}
/*
 * Inverse row-permuting copy: block b copies row b of inbuf into row pos[b]
 * of oubuf.  Launch contract: gridDim.x == number of rows; threads of a block
 * stride cooperatively over the `wid` elements of one row.  inbuf and oubuf
 * must not alias (callers pass distinct tensors), hence __restrict__.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* __restrict__ pos,
        const scalar_t* __restrict__ inbuf, scalar_t* __restrict__ oubuf) {
    // Advance to the dense source row b and the destination row pos[b].
    inbuf += wid * blockIdx.x;
    oubuf += wid * pos[blockIdx.x];
    // size_t index: the original used int, which is a signed/unsigned
    // comparison against wid and would overflow for rows > INT_MAX elements.
    for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}
/*
 * Launch helper for batch_gather_kernel: copies row b of `output_buf` into
 * row d_pos[b] of `output` for b in [0, batch_size).
 * One block per row, 256 threads per block, on the manager's stream 0,
 * followed by smgr->sync(1).
 */
template <typename scalar_t>
void fmoe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    // One block per row; 256 threads stride over the row's elements.
    const int n_threads = 256;
    batch_gather_kernel<scalar_t><<<batch_size, n_threads, 0, smgr->stream(0)>>>(
            out_feat, d_pos, output_buf, output);
    smgr->sync(1);
}
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
size_t numel, size_t topk) {
......@@ -60,7 +10,7 @@ void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
long gate_idx = gate[idx];
if (gate_idx > -1) {
int p = atomicSub(cum_count + gate_idx, 1);
pos[p] = (long)idx;
pos[p - 1] = (long)idx;
}
}
}
......@@ -71,7 +21,7 @@ void fmoe_cuda_assign_pos_impl(
CudaStreamManager* smgr) {
size_t numel = batch_size * topk;
assign_pos_kernel
<<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>(cum_count, gate, pos,
numel, topk);
<<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
(cum_count, gate, pos, numel, topk);
smgr->sync(1);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment