"vscode:/vscode.git/clone" did not exist on "5fab6734424a78a2a4594525386cd84feb67fb50"
Commit 92f1774a authored by Rick Ho's avatar Rick Ho
Browse files

moe backward (cannot pass test)

parent c91dfad8
@@ -22,10 +22,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor expert_count);
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output,
torch::Tensor input,
torch::Tensor gate,
torch::Tensor weight);
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor expert_count);
// C++ interface
@@ -58,7 +58,7 @@ std::vector<torch::Tensor> moe_local_gather(
std::vector<torch::Tensor> moe_forward(
torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x hidden_feat x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
    torch::Tensor expert_count // [num_expert]
) {
CHECK_INPUT(input_buf);
@@ -72,21 +72,20 @@ std::vector<torch::Tensor> moe_forward(
}
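For reference, a minimal PyTorch sketch of what the expert-parallel forward is expected to compute, assuming input_buf is already permuted so rows belonging to the same expert are contiguous and expert_count holds one row count per expert (moe_forward_ref is an illustrative name, not part of this code):

import torch

def moe_forward_ref(input_buf, weight, expert_count):
    # input_buf:    [batch_size, in_feat], grouped by expert
    # weight:       [num_expert, out_feat, in_feat]
    # expert_count: one count per expert, summing to batch_size
    output_buf = input_buf.new_zeros(input_buf.size(0), weight.size(1))
    ptr = 0
    for e, cnt in enumerate(expert_count.tolist()):
        if cnt == 0:
            continue
        # one GEMM per expert: y = x @ W_e^T
        output_buf[ptr:ptr + cnt] = input_buf[ptr:ptr + cnt] @ weight[e].t()
        ptr += cnt
    return output_buf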
std::vector<torch::Tensor> moe_backward(
torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
torch::Tensor gate, // [batch_size]
torch::Tensor weight // [num_expert x out_feat x in_feat]
torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
CHECK_INPUT(grad_output);
CHECK_INPUT(input);
CHECK_INPUT(gate);
CHECK_INPUT(grad_output_buf);
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
/*
The bias term should have been merged into weight. Note the following fact that
Wx + b = [W b] [x; 1]
*/
return moe_cuda_backward(grad_output, input, gate, weight);
return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
}
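The comment above folds the bias into the weight matrix; a quick standalone check of the identity Wx + b = [W b][x; 1] in PyTorch (names here are illustrative):

import torch

W, b, x = torch.randn(4, 3), torch.randn(4), torch.randn(3)
W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)   # [W b]
x_aug = torch.cat([x, x.new_ones(1)])           # [x; 1]
assert torch.allclose(W @ x + b, W_aug @ x_aug)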
......
@@ -16,21 +16,21 @@ class MOEFunction(Function):
output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
output = moe_cuda.local_gather(output_buf, pos)
variables = [inp, gate, weight, expert_count, pos]
variables = [input_buf, gate, weight, expert_count, pos]
ctx.save_for_backward(*variables)
return output[0]
@staticmethod
def backward(ctx, grad_out):
# print("grad_out", grad_out)
# print("input", ctx.saved_tensors[0])
grad_inp, grad_weight = moe_cuda.backward(
grad_out.contiguous(), *ctx.saved_tensors)
out_feat, in_feat = grad_weight.size()[1:]
# print("grad_weight_column_major", grad_weight.flatten())
grad_weight_row_major = grad_weight.view(-1, in_feat, out_feat).transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
return grad_inp, None, grad_weight_row_major
input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
grad_inp_buf, grad_weight = moe_cuda.backward(
grad_out_buf, input_buf, weight, expert_count)
grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)
return grad_inp, None, grad_weight
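Because the forward gathers rows of output_buf by pos, the backward above first scatters grad_out back to buffer order, runs the per-expert GEMMs, and then gathers grad_inp_buf. A small standalone check of that gather/scatter adjoint relationship, assuming local_gather(buf, pos)[i] == buf[pos[i]] and local_scatter writes row i to position pos[i] (these semantics are inferred, not stated in this diff):

import torch

buf = torch.randn(5, 3, requires_grad=True)
pos = torch.tensor([3, 0, 4, 1, 2])
grad_out = torch.randn(5, 3)

out = buf[pos]                 # gather, as in the forward
out.backward(grad_out)

scattered = torch.zeros_like(buf)
scattered[pos] = grad_out      # scatter, as in the backward
assert torch.allclose(buf.grad, scattered)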
class MOELayer(nn.Module):
@@ -82,9 +82,6 @@ def test_module(moe, linear, inp, gate):
moe.zero_grad()
x = linear(inp)
output = moe(x, gate)
print(output)
return output
print(output)
y = output.mean()
y.backward()
return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
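Since the commit message notes the backward cannot pass the test yet, a double-precision gradcheck is one way to localize the failing gradient; a sketch assuming moe is an MOELayer instance already on the GPU, gate is the int tensor of expert assignments used above, and the extension supports float64 (AT_DISPATCH_FLOATING_TYPES covers double):

import torch

def check_moe_grad(moe, gate, batch_size, in_feat):
    moe = moe.double()
    x = torch.randn(batch_size, in_feat, dtype=torch.float64,
                    device='cuda', requires_grad=True)
    # numerically checks the gradient with respect to the input x
    return torch.autograd.gradcheck(lambda t: moe(t, gate), (x,),
                                    eps=1e-6, atol=1e-4)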
......
@@ -124,9 +124,7 @@ void moe_cuda_forward_impl(
scalar_t* output_buf,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
cublasOperation_t transb) {
const size_t num_expert) {
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
@@ -151,40 +149,55 @@ void moe_cuda_forward_impl(
}
template <typename scalar_t>
void moe_cuda_grad_weight(
const scalar_t* input,
const int* gate,
const scalar_t* grad_output,
scalar_t* grad_weight, // [num_expert x out_feat x in_feat]
void moe_cuda_backward_impl(
const scalar_t* grad_output_buf,
const scalar_t* input_buf,
const scalar_t* weight,
const int* expert_count,
scalar_t* grad_input_buf,
scalar_t* grad_weight,
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert) {
ENSURE_SMGR(smgr, num_expert);
scalar_t alpha = 1, beta = 0;
int* gate_host = new int[batch_size];
scalar_t alpha = 1, beta = 1;
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) {
// checkCudaErrors(cublasSetStream);
checkCudaErrors(cublasXgemm(smgr.handles[0],
CUBLAS_OP_N,
CUBLAS_OP_T,
out_feat,
in_feat,
1,
&alpha,
grad_output + i * out_feat,
out_feat,
input + i * in_feat,
in_feat,
&beta,
grad_weight + gate_host[i] * out_feat * in_feat,
out_feat));
}
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
}
delete[] gate_host;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
cudaMemset(grad_weight + i * in_feat * out_feat, 0,
sizeof(scalar_t) * in_feat * out_feat);
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
// Backward input: g_i = w @ g_o
checkCudaErrors(cublasXgemm(smgr.handles[i],
CUBLAS_OP_N,
CUBLAS_OP_N,
in_feat, expert_count[i], out_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_input_buf + in_feat * ptr, in_feat
));
// Backward weight: g_w = i @ g_o
checkCudaErrors(cublasXgemm(smgr.handles[i],
CUBLAS_OP_N,
CUBLAS_OP_T,
in_feat, out_feat, expert_count[i],
&alpha,
input_buf + in_feat * ptr, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_weight + i * in_feat * out_feat, in_feat
));
ptr += expert_count[i];
}
smgr.sync();
}
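For comparison, a per-expert PyTorch reference of the two GEMMs above (grad_input = grad_out @ W and grad_weight = grad_out^T @ x in row-major terms), assuming rows of grad_output_buf and input_buf are grouped by expert and that the output buffers start from zero, since the kernels accumulate with beta = 1 (moe_backward_ref is an illustrative name):

import torch

def moe_backward_ref(grad_output_buf, input_buf, weight, expert_count):
    # grad_output_buf: [batch_size, out_feat], grouped by expert
    # input_buf:       [batch_size, in_feat],  grouped the same way
    # weight:          [num_expert, out_feat, in_feat]
    grad_input_buf = torch.zeros_like(input_buf)
    grad_weight = torch.zeros_like(weight)
    ptr = 0
    for e, cnt in enumerate(expert_count.tolist()):
        if cnt == 0:
            continue
        g = grad_output_buf[ptr:ptr + cnt]              # [cnt, out_feat]
        x = input_buf[ptr:ptr + cnt]                    # [cnt, in_feat]
        grad_input_buf[ptr:ptr + cnt] = g @ weight[e]   # backward input
        grad_weight[e] = g.t() @ x                      # backward weight
        ptr += cnt
    return grad_input_buf, grad_weight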
@@ -285,8 +298,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
output.data_ptr<scalar_t>(),
in_feat,
out_feat,
num_expert,
CUBLAS_OP_T
num_expert
);
}));
@@ -294,49 +306,32 @@ std::vector<torch::Tensor> moe_cuda_forward(
}
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output, // [batch_size x out_feat]
    torch::Tensor input, // [batch_size x in_feat]
torch::Tensor gate, // [batch_size]
torch::Tensor weight // [num_expert x out_feat x in_feat]
torch::Tensor grad_output_buf, // [batch_size x out_feat]
    torch::Tensor input_buf, // [batch_size x in_feat]
torch::Tensor weight, // [num_expert x out_feat x in_feat]
torch::Tensor expert_count
) {
const auto batch_size = input.size(0);
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif
const int device = device_of(input_buf).value().index();
if (smgr.streams == NULL) {
smgr.setup(num_expert, device);
}
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
auto grad_input_buf = grad_output_buf.new_zeros({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_zeros({num_expert, out_feat, in_feat});
// zero-initialized: the GEMMs below accumulate with beta = 1
// grad_input is easy to compute, exactly the same as forward
/* TODO: Backward currently broken
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
moe_cuda_forward_impl<scalar_t>(
grad_output.data_ptr<scalar_t>(),
gate.data_ptr<int>(),
AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
moe_cuda_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(),
batch_size,
out_feat,
in_feat,
num_expert,
CUBLAS_OP_N
);
}));
*/
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
moe_cuda_grad_weight<scalar_t>(
input.data_ptr<scalar_t>(),
gate.data_ptr<int>(),
grad_output.data_ptr<scalar_t>(),
expert_count.data_ptr<int>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
batch_size,
in_feat,
@@ -345,7 +340,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
);
}));
return {grad_input, grad_weight};
return {grad_input_buf, grad_weight};
}