Commit d2039fc7 authored by Rick Ho
Browse files

swap local scatter and gather kernel functions

parent 14c0eab4
......@@ -31,8 +31,8 @@ template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
......@@ -92,8 +92,8 @@ template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
const scalar_t* inbuf, scalar_t* oubuf) {
inbuf += wid * pos[blockIdx.x];
oubuf += wid * blockIdx.x;
inbuf += wid * blockIdx.x;
oubuf += wid * pos[blockIdx.x];
for (int i = threadIdx.x; i < wid; i += blockDim.x) {
oubuf[i] = inbuf[i];
}
......
......@@ -29,7 +29,7 @@ class MOEScatter(Function):
@staticmethod
def forward(ctx, inp, pos, local_expert_count, global_expert_count,
fwd_batch_size, world_size):
local_input_buf, = fmoe_cuda.local_gather(inp, pos)
local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
if world_size > 1:
global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
local_expert_count, global_expert_count,
......@@ -52,7 +52,7 @@ class MOEScatter(Function):
local_batch_size, world_size)
else:
local_grad_in = global_grad_in
grad_in, = fmoe_cuda.local_scatter(local_grad_in, pos)
grad_in, = fmoe_cuda.local_gather(local_grad_in, pos)
return grad_in, None, None, None, None, None
......@@ -83,7 +83,7 @@ class MOEGather(Function):
local_batch_size, world_size)
else:
local_output_buf = global_output_buf
output, = fmoe_cuda.local_scatter(local_output_buf, pos)
output, = fmoe_cuda.local_gather(local_output_buf, pos)
ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
variables = (pos, local_expert_count, global_expert_count)
......@@ -94,7 +94,7 @@ class MOEGather(Function):
def backward(ctx, grad_out):
pos, local_expert_count, global_expert_count = ctx.saved_tensors
local_batch_size, fwd_batch_size, world_size = ctx.moe_args
grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
if world_size > 1:
global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
local_expert_count, global_expert_count,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment