"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "42067ef2628320aa28cc79eb7d8bca97088f934e"
Commit 92f1774a authored by Rick Ho

moe backward (cannot pass test)

parent c91dfad8
```diff
@@ -22,10 +22,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor expert_count);
 
 std::vector<torch::Tensor> moe_cuda_backward(
-    torch::Tensor grad_output,
-    torch::Tensor input,
-    torch::Tensor gate,
-    torch::Tensor weight);
+    torch::Tensor grad_output_buf,
+    torch::Tensor input_buf,
+    torch::Tensor weight,
+    torch::Tensor expert_count);
 
 // C++ interface
@@ -58,7 +58,7 @@ std::vector<torch::Tensor> moe_local_gather(
 
 std::vector<torch::Tensor> moe_forward(
     torch::Tensor input_buf,    // [batch_size x in_feat]
-    torch::Tensor weight,       // [num_expert x hidden_feat x in_feat]
+    torch::Tensor weight,       // [num_expert x out_feat x in_feat]
     torch::Tensor expert_count  // [batch_size]
 ) {
     CHECK_INPUT(input_buf);
@@ -72,21 +72,20 @@ std::vector<torch::Tensor> moe_forward(
 }
 
 std::vector<torch::Tensor> moe_backward(
-    torch::Tensor grad_output,  // [batch_size x out_feat]
-    torch::Tensor input,        // [batch_size x in_feat]
-    torch::Tensor gate,         // [batch_size]
-    torch::Tensor weight        // [num_expert x out_feat x in_feat]
+    torch::Tensor grad_output_buf,  // [batch_size x out_feat]
+    torch::Tensor input_buf,        // [batch_size x in_feat]
+    torch::Tensor weight,           // [num_expert x out_feat x in_feat]
+    torch::Tensor expert_count
 ) {
-    CHECK_INPUT(grad_output);
-    CHECK_INPUT(input);
-    CHECK_INPUT(gate);
+    CHECK_INPUT(grad_output_buf);
+    CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
     /*
        The bias term should have been merged into weight. Note the following fact:
           Wx + b = [W b] [x]
                          [1]
     */
-    return moe_cuda_backward(grad_output, input, gate, weight);
+    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
 }
```
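The bias-merging fact cited in the comment above, written out: appending the bias as an extra column of W and a constant 1 to x folds the affine map into a single matrix product, which is why neither interface carries a separate bias term.

```latex
W x + b =
\begin{bmatrix} W & b \end{bmatrix}
\begin{bmatrix} x \\ 1 \end{bmatrix}
```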
...
```diff
@@ -16,21 +16,21 @@ class MOEFunction(Function):
         output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
         output = moe_cuda.local_gather(output_buf, pos)
 
-        variables = [inp, gate, weight, expert_count, pos]
+        variables = [input_buf, gate, weight, expert_count, pos]
         ctx.save_for_backward(*variables)
         return output[0]
 
     @staticmethod
     def backward(ctx, grad_out):
-        # print("grad_out", grad_out)
-        # print("input", ctx.saved_tensors[0])
-        grad_inp, grad_weight = moe_cuda.backward(
-            grad_out.contiguous(), *ctx.saved_tensors)
-        out_feat, in_feat = grad_weight.size()[1:]
-        # print("grad_weight_column_major", grad_weight.flatten())
-        grad_weight_row_major = grad_weight.view(-1, in_feat, out_feat).transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
-        return grad_inp, None, grad_weight_row_major
+        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
+        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
+        grad_inp_buf, grad_weight = moe_cuda.backward(
+            grad_out_buf, input_buf, weight, expert_count)
+        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)
+        return grad_inp, None, grad_weight
 
 
 class MOELayer(nn.Module):
```
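Since the commit title says the backward cannot pass the test yet, a pure-PyTorch reference of the same scatter → per-expert GEMM → gather dataflow is useful for localizing which stage disagrees. A minimal sketch, assuming `local_scatter` permutes rows by `pos` and `local_gather` inverts that permutation (these semantics are inferred from the diff, not confirmed):

```python
import torch

def reference_moe_backward(grad_out, inp, pos, expert_count, weight):
    # Mirror of the CUDA backward: scatter, per-expert GEMMs, gather.
    grad_out_buf = grad_out[pos]   # assumed local_scatter: buf[i] = x[pos[i]]
    input_buf = inp[pos]

    grad_input_buf = torch.empty_like(input_buf)
    grad_weight = torch.zeros_like(weight)   # [num_expert, out_feat, in_feat]

    ptr = 0
    for i, cnt in enumerate(expert_count.tolist()):
        if cnt == 0:
            continue
        g_o = grad_out_buf[ptr:ptr + cnt]                # [cnt, out_feat]
        x = input_buf[ptr:ptr + cnt]                     # [cnt, in_feat]
        grad_input_buf[ptr:ptr + cnt] = g_o @ weight[i]  # g_i = g_o @ W_i
        grad_weight[i] = g_o.t() @ x                     # g_W = g_o^T @ x
        ptr += cnt

    grad_input = torch.empty_like(grad_input_buf)
    grad_input[pos] = grad_input_buf   # assumed local_gather: out[pos[i]] = buf[i]
    return grad_input, grad_weight
```

Comparing `moe_cuda.backward` against this on small random tensors, one expert at a time, narrows a mismatch down to the scatter, the GEMMs, or the gather.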
```diff
@@ -82,9 +82,6 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = linear(inp)
     output = moe(x, gate)
-    print(output)
-    return output
-    print(output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
```
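With the stray `print`/early-`return` removed, the test exercises the full backward. Given the "cannot pass test" status, a double-precision `torch.autograd.gradcheck` is the quickest way to tell whether the analytic gradients match finite differences. A sketch, assuming `MOEFunction.apply(inp, gate, weight)` as suggested by the three-value return of `backward`:

```python
import torch
from torch.autograd import gradcheck

# Small hypothetical sizes; names follow the diff.
num_expert, batch_size, in_feat, out_feat = 4, 8, 6, 5

inp = torch.randn(batch_size, in_feat, dtype=torch.double,
                  device='cuda', requires_grad=True)
gate = torch.randint(num_expert, (batch_size,), dtype=torch.int32, device='cuda')
weight = torch.randn(num_expert, out_feat, in_feat, dtype=torch.double,
                     device='cuda', requires_grad=True)

# Only the floating-point inputs are perturbed; the integer gate passes through.
assert gradcheck(MOEFunction.apply, (inp, gate, weight), eps=1e-6, atol=1e-4)
```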
...
```diff
@@ -124,9 +124,7 @@ void moe_cuda_forward_impl(
         scalar_t* output_buf,
         const size_t in_feat,
         const size_t out_feat,
-        const size_t num_expert,
-        cublasOperation_t transb) {
+        const size_t num_expert) {
 
     scalar_t alpha = 1, beta = 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
```
```diff
@@ -151,40 +149,55 @@ void moe_cuda_forward_impl(
 }
 
 template <typename scalar_t>
-void moe_cuda_grad_weight(
-    const scalar_t* input,
-    const int* gate,
-    const scalar_t* grad_output,
-    scalar_t* grad_weight,  // [num_expert x out_feat x in_feat]
+void moe_cuda_backward_impl(
+    const scalar_t* grad_output_buf,
+    const scalar_t* input_buf,
+    const scalar_t* weight,
+    const int* expert_count,
+    scalar_t* grad_input_buf,
+    scalar_t* grad_weight,
     const size_t batch_size,
     const size_t in_feat,
     const size_t out_feat,
     const size_t num_expert) {
+    ENSURE_SMGR(smgr, num_expert);
+    scalar_t alpha = 1, beta = 0;
 
-    int* gate_host = new int[batch_size];
-    scalar_t alpha = 1, beta = 1;
-    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
-    for (size_t i=0; i<batch_size; ++i) {
-        // checkCudaErrors(cublasSetStream);
-        checkCudaErrors(cublasXgemm(smgr.handles[0],
-            CUBLAS_OP_N,
-            CUBLAS_OP_T,
-            out_feat,
-            in_feat,
-            1,
-            &alpha,
-            grad_output + i * out_feat,
-            out_feat,
-            input + i * in_feat,
-            in_feat,
-            &beta,
-            grad_weight + gate_host[i] * out_feat * in_feat,
-            out_feat));
-    }
-    for (size_t i=0; i<num_expert; ++i) {
-        checkCudaErrors(cudaStreamSynchronize(*(smgr.streams + i)));
-    }
-    delete[] gate_host;
+    for (int i = 0, ptr = 0; i < num_expert; ++i) {
+        if (expert_count[i] == 0) {
+            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
+                    sizeof(scalar_t) * in_feat * out_feat);
+            continue;
+        }
+        // Use T(B) x T(A) = T(C) to produce row-major C
+        // Backward input: g_i = w @ g_o
+        checkCudaErrors(cublasXgemm(smgr.handles[i],
+            CUBLAS_OP_N,
+            CUBLAS_OP_N,
+            in_feat, expert_count[i], out_feat,
+            &alpha,
+            weight + i * in_feat * out_feat, in_feat,
+            grad_output_buf + ptr * out_feat, out_feat,
+            &beta,
+            grad_input_buf + in_feat * ptr, in_feat
+            ));
+        // Backward weight: g_w = i @ g_o
+        checkCudaErrors(cublasXgemm(smgr.handles[i],
+            CUBLAS_OP_N,
+            CUBLAS_OP_T,
+            in_feat, out_feat, expert_count[i],
+            &alpha,
+            input_buf + in_feat * ptr, in_feat,
+            grad_output_buf + ptr * out_feat, out_feat,
+            &beta,
+            grad_weight + i * in_feat * out_feat, in_feat
+            ));
+        ptr += expert_count[i];
+    }
+    smgr.sync();
 }
```
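The `T(B) x T(A) = T(C)` comment is the usual trick for driving column-major cuBLAS with row-major buffers: the same memory read column-major is the transpose, so computing Cᵀ = Bᵀ·Aᵀ with no transpose flags yields a row-major C. A small check of the identity behind the first GEMM (pure PyTorch, shapes as in the kernel):

```python
import torch

cnt, in_feat, out_feat = 3, 4, 5
g_o = torch.randn(cnt, out_feat)    # row-major grad_output_buf slice
w = torch.randn(out_feat, in_feat)  # row-major expert weight

# cuBLAS sees each row-major buffer as its transpose (column-major view).
w_cm = w.t()      # in_feat x out_feat, leading dim in_feat
g_o_cm = g_o.t()  # out_feat x cnt,    leading dim out_feat

# OP_N, OP_N: (in_feat x out_feat) @ (out_feat x cnt) = in_feat x cnt,
# which re-read as row-major is exactly g_o @ w, i.e. grad_input.
c_cm = w_cm @ g_o_cm
assert torch.allclose(c_cm.t(), g_o @ w)
```

The grad-weight GEMM uses `CUBLAS_OP_T` on `grad_output_buf` the same way, producing `g_o.t() @ x` laid out as row-major `[out_feat, in_feat]`.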
```diff
@@ -285,8 +298,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             output.data_ptr<scalar_t>(),
             in_feat,
             out_feat,
-            num_expert,
-            CUBLAS_OP_T
+            num_expert
         );
     }));
@@ -294,49 +306,32 @@ std::vector<torch::Tensor> moe_cuda_forward(
 }
 
 std::vector<torch::Tensor> moe_cuda_backward(
-    torch::Tensor grad_output,  // [batch_size x out_feat]
-    torch::Tensor input,        // [batch_size x in_feat]
-    torch::Tensor gate,         // [batch_size]
-    torch::Tensor weight        // [num_expert x out_feat x in_feat]
+    torch::Tensor grad_output_buf,  // [batch_size x out_feat]
+    torch::Tensor input_buf,        // [batch_size x in_feat]
+    torch::Tensor weight,           // [num_expert x out_feat x in_feat]
+    torch::Tensor expert_count
 ) {
-    const auto batch_size = input.size(0);
+    const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
 
 #ifdef MOE_DEBUG
-    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
+    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
+            "out_feat (d_ffn)=%ld\n",
+            batch_size, num_expert, in_feat, out_feat);
 #endif
 
-    const int device = device_of(input).value().index();
-    if (smgr.streams == NULL) {
-        smgr.setup(num_expert, device);
-    }
-
-    auto grad_input = grad_output.new_zeros({batch_size, in_feat});  // batch_size x in_feat
-    auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat});  // num_expert x out_feat x in_feat
+    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
+    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
 
-    // grad_input is easy to compute, exactly the same as forward
-    /* TODO: Backward currently brokenn
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
-        moe_cuda_forward_impl<scalar_t>(
-            grad_output.data_ptr<scalar_t>(),
-            gate.data_ptr<int>(),
-            weight.data_ptr<scalar_t>(),
-            grad_input.data_ptr<scalar_t>(),
-            batch_size,
-            out_feat,
-            in_feat,
-            num_expert,
-            CUBLAS_OP_N
-        );
-    }));
-    */
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
-        moe_cuda_grad_weight<scalar_t>(
-            input.data_ptr<scalar_t>(),
-            gate.data_ptr<int>(),
-            grad_output.data_ptr<scalar_t>(),
+    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
+        moe_cuda_backward_impl<scalar_t>(
+            grad_output_buf.data_ptr<scalar_t>(),
+            input_buf.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
+            expert_count.data_ptr<int>(),
+            grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
             batch_size,
             in_feat,
@@ -345,7 +340,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
         );
     }));
 
-    return {grad_input, grad_weight};
+    return {grad_input_buf, grad_weight};
 }
```
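One caveat worth flagging, and a plausible reason the test still fails: `moe_cuda_backward_impl` reads `expert_count[i]` in host code, so the tensor passed down from Python must live in CPU memory; handing it a pointer into a CUDA tensor would make the host loop dereference device memory. A hedged sketch of producing counts the host loop can index (`torch.bincount` here is my assumption, not necessarily what the repo does):

```python
import torch

num_expert = 4
gate = torch.randint(num_expert, (16,), device='cuda')

# Tokens per expert as a CPU int32 tensor, directly indexable from C++ host code.
expert_count = torch.bincount(gate, minlength=num_expert).to(torch.int32).cpu()
```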
...