"src/array/cuda/array_cumsum.hip" did not exist on "0ff7127a0fff730f3c41a8ea3e967c1155993a2f"
Commit a526f438 authored by Rick Ho's avatar Rick Ho
Browse files

single node use torch cuda expert count

parent bc8e8181
@@ -20,7 +20,7 @@
 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
-        const int* offset, const scalar_t** ptrs) {
+        const long* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
@@ -29,7 +29,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
 template <typename scalar_t>
 __global__
-void batch_scatter_kernel(size_t wid, const int* pos,
+void batch_scatter_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * blockIdx.x;
     oubuf += wid * pos[blockIdx.x];
@@ -77,7 +77,7 @@ void moe_cuda_expert_count_impl(
 template <typename scalar_t>
 void moe_cuda_local_scatter_impl(
         const scalar_t* input,
-        const int* d_pos,
+        const long* d_pos,
         scalar_t* input_buf,
         const long batch_size,
         const long in_feat,
@@ -90,7 +90,7 @@ void moe_cuda_local_scatter_impl(
 template <typename scalar_t>
 __global__
-void batch_gather_kernel(size_t wid, const int* pos,
+void batch_gather_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * pos[blockIdx.x];
     oubuf += wid * blockIdx.x;
@@ -102,7 +102,7 @@ void batch_gather_kernel(size_t wid, const int* pos,
 template <typename scalar_t>
 void moe_cuda_local_gather_impl(
         const scalar_t* output_buf,
-        const int* d_pos,
+        const long* d_pos,
         scalar_t* output,
         const size_t batch_size,
         const size_t out_feat,
@@ -117,7 +117,7 @@ template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input_buf,
         const scalar_t* weight,
-        const int* expert_count,
+        const long* expert_count,
         scalar_t* output_buf,
         const size_t in_feat,
         const size_t out_feat,
@@ -152,7 +152,7 @@ void moe_cuda_backward_impl(
         const scalar_t* grad_output_buf,
         const scalar_t* input_buf,
         const scalar_t* weight,
-        const int* expert_count,
+        const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
         const size_t batch_size,
@@ -237,7 +237,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
         ([&] {
             moe_cuda_local_scatter_impl<scalar_t>(
                 input.data_ptr<scalar_t>(),
-                pos.data_ptr<int>(),
+                pos.data_ptr<long>(),
                 input_buf.data_ptr<scalar_t>(),
                 batch_size,
                 in_feat,
@@ -259,7 +259,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
         ([&] {
             moe_cuda_local_gather_impl<scalar_t>(
                 output_buf.data_ptr<scalar_t>(),
-                pos.data_ptr<int>(),
+                pos.data_ptr<long>(),
                 output.data_ptr<scalar_t>(),
                 batch_size,
                 out_feat,
@@ -293,7 +293,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             moe_cuda_forward_impl<scalar_t>(
                 input_buf.data_ptr<scalar_t>(),
                 weight.data_ptr<scalar_t>(),
-                expert_count.data_ptr<int>(),
+                expert_count.data_ptr<long>(),
                 output.data_ptr<scalar_t>(),
                 in_feat,
                 out_feat,
@@ -331,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
                 grad_output_buf.data_ptr<scalar_t>(),
                 input_buf.data_ptr<scalar_t>(),
                 weight.data_ptr<scalar_t>(),
-                expert_count.data_ptr<int>(),
+                expert_count.data_ptr<long>(),
                 grad_input_buf.data_ptr<scalar_t>(),
                 grad_weight.data_ptr<scalar_t>(),
                 batch_size,
...
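The signature changes above widen every index buffer from const int* to const long* because the position and count tensors are now built with plain torch ops on the Python side (see the second file below), and both the indices from torch.sort and the counts from torch.unique(..., return_counts=True) come back as int64; Tensor::data_ptr<long>() then matches a torch.long tensor on the usual LP64 builds. A minimal sketch of that dtype assumption, with a made-up gate tensor purely for illustration:

import torch

# Hypothetical gate: one expert index per token.
gate = torch.tensor([1, 0, 2, 1, 0])

# torch.sort returns int64 indices ...
_, pos = torch.sort(gate)
assert pos.dtype == torch.int64

# ... and torch.unique returns int64 counts, so the CUDA side reads both
# through data_ptr<long>() rather than data_ptr<int>().
_, counts = torch.unique(gate, return_counts=True)
assert counts.dtype == torch.int64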
@@ -6,12 +6,19 @@ import fmoe_cuda
 class MOELocal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight):
-        expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        expert_count = torch.zeros(weight.shape[0], device=weight.device,
+                dtype=torch.long)
+        expert_count.index_put_((gate_idx.long(), ), gate_count)
+        # expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
+        ecc = expert_count.cpu()
         input_buf, = fmoe_cuda.local_scatter(inp, pos)
-        output_buf, = fmoe_cuda.forward(input_buf, weight, expert_count)
+        output_buf, = fmoe_cuda.forward(input_buf, weight, ecc)
         output = fmoe_cuda.local_gather(output_buf, pos)
-        variables = [input_buf, gate, weight, expert_count, pos]
+        variables = [input_buf, gate, weight, ecc, pos]
         ctx.save_for_backward(*variables)
         return output[0]
...
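For reference, the torch-only expert counting introduced above can be read as a standalone helper. The sketch below is illustrative rather than repository code; count_experts and its arguments are assumed names, and it presumes gate holds one expert index per token:

import torch

def count_experts(gate, num_expert):
    # Sort tokens by expert id; pos is the permutation consumed by
    # local_scatter / local_gather.
    _, pos = torch.sort(gate)
    # Histogram of tokens per expert; experts that receive no tokens stay zero.
    gate_idx, gate_count = torch.unique(gate, return_counts=True)
    expert_count = torch.zeros(num_expert, device=gate.device, dtype=torch.long)
    expert_count.index_put_((gate_idx.long(),), gate_count)
    # The .cpu() copy mirrors ecc = expert_count.cpu() in the diff, which
    # hands a host tensor to fmoe_cuda.forward.
    return pos, expert_count.cpu()

torch.bincount(gate, minlength=num_expert) would produce the same counts in a single call; the unique/index_put_ form simply mirrors the diff above.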