Commit 32e35812 authored by Jiezhong Qiu

update

parent d4dd2a6c
@@ -10,18 +10,18 @@ torch.cuda.manual_seed(42)
 class MOEFunction(Function):
     @staticmethod
-    def forward(ctx, input, gate, weight):
-        output = moe_cuda.forward(input, gate, weight)
-        variables = [input, gate, weight]
+    def forward(ctx, inp, gate, weight):
+        output = moe_cuda.forward(inp, gate, weight)
+        variables = [inp, gate, weight]
         ctx.save_for_backward(*variables)
         return output[0]

     @staticmethod
     def backward(ctx, grad_out):
-        grad_input, grad_weight = moe_cuda.backward(
+        grad_inp, grad_weight = moe_cuda.backward(
             grad_out.contiguous(), *ctx.saved_tensors)
-        return grad_input, None, grad_weight
+        return grad_inp, None, grad_weight

 class MOELayer(nn.Module):
@@ -39,8 +39,8 @@ class MOELayer(nn.Module):
             linear = nn.Linear(in_features=self.in_feat, out_features=out_feat)
             self.weight.data[i] = linear.weight.data

-    def forward(self, input, gate):
-        return MOEFunction.apply(input, gate, self.weight)
+    def forward(self, inp, gate):
+        return MOEFunction.apply(inp, gate, self.weight)

 class MOELayer_einsum(nn.Module):
@@ -58,36 +58,37 @@ class MOELayer_einsum(nn.Module):
             linear = nn.Linear(in_features=self.in_feat, out_features=out_feat)
             self.weight.data[i] = linear.weight.data

-    def forward(self, input, gate):
+    def forward(self, inp, gate):
         gate_long = gate.long()
-        #W = self.weight[gate_long] # [batch_size x out_feat x in_feat]
-        #x = torch.einsum('id,ihd->ih', (input, W)) # [batch_size x out_feat]
-        #return x
-        batch_size = input.size(0)
-        x = input.new_zeros((batch_size, self.out_feat))
+        batch_size = inp.size(0)
+        x = inp.new_zeros((batch_size, self.out_feat))
         for i in range(batch_size):
-            x[i] = self.weight[gate_long[i]] @ input[i]
+            x[i] = self.weight[gate_long[i]] @ inp[i]
         return x

-batch_size = 2
-num_expert = 2
-in_feat = 2
-out_feat = 4
+batch_size = 1
+num_expert = 1
+in_feat = 3
+out_feat = 3

 moe = MOELayer(num_expert, in_feat, out_feat).cuda()
 moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
 moe_einsum.weight.data = moe.weight.data.clone()

-input = torch.rand(batch_size, in_feat).cuda()
+inp = torch.rand(batch_size, in_feat).cuda()
 gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

-print(input)
+print(inp.type())
+print(moe.weight.data.type())
+print(inp)
 print(gate)
-output = moe(input, gate)
+output = moe(inp, gate)

-print(input)
+print(inp)
 print(gate)
-output_einsum = moe_einsum(input, gate)
+output_einsum = moe_einsum(inp.clone(), gate.clone())

 print(output)
 print(output_einsum)
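The test-script changes above rename input to inp (presumably to avoid shadowing Python's built-in input), replace the commented-out einsum reference with an explicit per-sample loop, shrink the test dimensions, and add dtype/debug prints. As a rough sketch of what the reference path computes, the loop in MOELayer_einsum is equivalent to gathering each sample's expert weight and applying a batched matmul; the helper name reference_moe and the torch.allclose comparison below are illustrative assumptions, not code from the commit.

import torch

def reference_moe(inp, gate, weight):
    # weight: [num_expert, out_feat, in_feat]; gate: int tensor of expert ids per sample.
    # Gather the chosen expert's weight for each sample and apply it with a batched matmul,
    # which matches the explicit loop x[i] = weight[gate[i]] @ inp[i] in MOELayer_einsum.
    w = weight[gate.long()]                              # [batch_size, out_feat, in_feat]
    return torch.bmm(w, inp.unsqueeze(-1)).squeeze(-1)   # [batch_size, out_feat]

# Hypothetical check against the CUDA layer defined above:
# print(torch.allclose(moe(inp, gate), reference_moe(inp, gate, moe.weight.data), atol=1e-5))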
@@ -161,9 +161,12 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
     checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));

+    int* gate_host = new int[batch_size];
+    checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i=0; i<batch_size; ++i) {
         aptrs.push_back(input + in_feat * i);
-        bptrs.push_back(weight + out_feat * in_feat * i);
+        bptrs.push_back(weight + out_feat * in_feat * gate_host[i]);
         cptrs.push_back(output + out_feat * i);
     }
     checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
@@ -177,9 +180,12 @@ void moe_cuda_forward_impl(
     const scalar_t **B = (const scalar_t **)malloc(batch_size * sizeof(const scalar_t*));
     checkCudaErrors(cudaMemcpy(B, Barray, batch_size * sizeof(const scalar_t*), cudaMemcpyDeviceToHost));
-    std::cout << weight << std::endl;
+    std::cout << input << " " << weight << " " << output << std::endl;
     for (size_t i=0; i<batch_size; ++i) {
-        std::cout << B[i] << " " << bptrs[i] << std::endl;
+        std::cout << i << std::endl;
+        std::cout << "A " << aptrs[i] << std::endl;
+        std::cout << "B " << B[i] << " " << bptrs[i] << std::endl;
+        std::cout << "C " << cptrs[i] << std::endl;
     }
     scalar_t alpha = 1, beta = 0;
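On the CUDA side, the substantive change is that the B-pointer array fed to the batched GEMM is now indexed by the gate value instead of the sample index, so each input row is multiplied by the weight matrix of its assigned expert; the gate tensor is first copied device-to-host because the pointer arrays are assembled on the CPU before being uploaded. A minimal sketch of that addressing, assuming the flat weight buffer is laid out as [num_expert x out_feat x in_feat]; the function name and the Python formulation are illustrative only, not code from the kernel.

def batched_gemm_offsets(batch_size, in_feat, out_feat, gate_host):
    # Element offsets into the flat input/weight/output buffers, mirroring
    # the aptrs/bptrs/cptrs arrays the host code builds before the batched GEMM.
    a_off = [in_feat * i for i in range(batch_size)]       # row i of input
    b_off = [out_feat * in_feat * g for g in gate_host]    # expert picked by gate[i]
    c_off = [out_feat * i for i in range(batch_size)]      # row i of output
    return a_off, b_off, c_off

# e.g. batched_gemm_offsets(2, 3, 3, [1, 0]) -> ([0, 3], [9, 0], [0, 3])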