Commit 5d076dcf authored by Jiezhong Qiu
Browse files

update

parent 42b825fe
...@@ -21,10 +21,13 @@ class MOEFunction(Function): ...@@ -21,10 +21,13 @@ class MOEFunction(Function):
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
print("grad_out", grad_out)
print("input", ctx.saved_tensors[0])
grad_inp, grad_weight = moe_cuda.backward( grad_inp, grad_weight = moe_cuda.backward(
grad_out.contiguous(), *ctx.saved_tensors) grad_out.contiguous(), *ctx.saved_tensors)
out_feat, in_feat = grad_weight.size()[1:] out_feat, in_feat = grad_weight.size()[1:]
grad_weight_row_major = grad_weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat) print("grad_weight_column_major", grad_weight.flatten())
grad_weight_row_major = grad_weight.view(-1, in_feat, out_feat).transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
return grad_inp, None, grad_weight_row_major return grad_inp, None, grad_weight_row_major
...@@ -47,9 +50,9 @@ class MOELayer(nn.Module): ...@@ -47,9 +50,9 @@ class MOELayer(nn.Module):
return MOEFunction.apply(inp, gate, self.weight) return MOEFunction.apply(inp, gate, self.weight)
class MOELayer_einsum(nn.Module): class MOELayer_raw(nn.Module):
def __init__(self, num_expert=32, in_feat=1024, out_feat=4096): def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
super(MOELayer_einsum, self).__init__() super(MOELayer_raw, self).__init__()
self.num_expert = num_expert self.num_expert = num_expert
self.in_feat = in_feat self.in_feat = in_feat
self.out_feat = out_feat self.out_feat = out_feat
...@@ -71,23 +74,29 @@ class MOELayer_einsum(nn.Module): ...@@ -71,23 +74,29 @@ class MOELayer_einsum(nn.Module):
return x return x
batch_size = 4 batch_size = 4
num_expert = 4 num_expert = 8
in_feat = 2 in_feat = 2
out_feat = 3 out_feat = 3
moe = MOELayer(num_expert, in_feat, out_feat).cuda() moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda() moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
moe_einsum.weight.data = moe.weight.data.clone() moe_raw.weight.data = moe.weight.data.clone()
inp = torch.rand(batch_size, in_feat).cuda() inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
output = moe(inp, gate) output = moe(inp, gate)
output_einsum = moe_einsum(inp.clone(), gate.clone()) output_raw = moe_raw(inp.clone(), gate.clone())
print(output) #print(output)
print(output_einsum) #print(output_raw)
#y = output.mean() y = output.mean()
#y.backward() y.backward()
\ No newline at end of file
y_raw = output_raw.mean()
y_raw.backward()
print(moe.weight.grad)
print(moe_raw.weight.grad)
...@@ -203,7 +203,7 @@ void moe_cuda_grad_weight( ...@@ -203,7 +203,7 @@ void moe_cuda_grad_weight(
checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i]))); checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
checkCudaErrors(cublasXgemm(h->handle, checkCudaErrors(cublasXgemm(h->handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_T,
out_feat, out_feat,
in_feat, in_feat,
1, 1,
...@@ -211,7 +211,7 @@ void moe_cuda_grad_weight( ...@@ -211,7 +211,7 @@ void moe_cuda_grad_weight(
grad_output + i * out_feat, grad_output + i * out_feat,
out_feat, out_feat,
input + i * in_feat, input + i * in_feat,
1, in_feat,
&beta, &beta,
grad_weight + gate_host[i] * out_feat * in_feat, grad_weight + gate_host[i] * out_feat * in_feat,
out_feat)); out_feat));
...@@ -229,7 +229,7 @@ std::vector<torch::Tensor> moe_cuda_forward( ...@@ -229,7 +229,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat); printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
auto output = input.new_zeros({batch_size, out_feat}); auto output = input.new_zeros({batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
...@@ -259,6 +259,7 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -259,6 +259,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
const auto num_expert = weight.size(0); const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
...@@ -285,8 +286,8 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -285,8 +286,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
grad_output.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(), grad_weight.data_ptr<scalar_t>(),
batch_size, batch_size,
out_feat,
in_feat, in_feat,
out_feat,
num_expert num_expert
); );
})); }));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment