Commit 32e35812 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

update

parent d4dd2a6c
......@@ -10,18 +10,18 @@ torch.cuda.manual_seed(42)
class MOEFunction(Function):
@staticmethod
def forward(ctx, input, gate, weight):
output = moe_cuda.forward(input, gate, weight)
variables = [input, gate, weight]
def forward(ctx, inp, gate, weight):
output = moe_cuda.forward(inp, gate, weight)
variables = [inp, gate, weight]
ctx.save_for_backward(*variables)
return output[0]
@staticmethod
def backward(ctx, grad_out):
grad_input, grad_weight = moe_cuda.backward(
grad_inp, grad_weight = moe_cuda.backward(
grad_out.contiguous(), *ctx.saved_tensors)
return grad_input, None, grad_weight
return grad_inp, None, grad_weight
class MOELayer(nn.Module):
......@@ -39,8 +39,8 @@ class MOELayer(nn.Module):
linear = nn.Linear(in_features=self.in_feat, out_features=out_feat)
self.weight.data[i] = linear.weight.data
def forward(self, input, gate):
return MOEFunction.apply(input, gate, self.weight)
def forward(self, inp, gate):
return MOEFunction.apply(inp, gate, self.weight)
class MOELayer_einsum(nn.Module):
......@@ -58,36 +58,37 @@ class MOELayer_einsum(nn.Module):
linear = nn.Linear(in_features=self.in_feat, out_features=out_feat)
self.weight.data[i] = linear.weight.data
def forward(self, input, gate):
def forward(self, inp, gate):
gate_long = gate.long()
#W = self.weight[gate_long] # [batch_size x out_feat x in_feat]
#x = torch.einsum('id,ihd->ih', (input, W)) # [batch_size x out_feat]
#return x
batch_size = input.size(0)
x = input.new_zeros((batch_size, self.out_feat))
batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.out_feat))
for i in range(batch_size):
x[i] = self.weight[gate_long[i]] @ input[i]
x[i] = self.weight[gate_long[i]] @ inp[i]
return x
batch_size = 2
num_expert = 2
in_feat = 2
out_feat = 4
batch_size = 1
num_expert = 1
in_feat = 3
out_feat = 3
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
moe_einsum.weight.data = moe.weight.data.clone()
input = torch.rand(batch_size, in_feat).cuda()
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
print(input)
print(inp.type())
print(moe.weight.data.type())
print(inp)
print(gate)
output = moe(input, gate)
output = moe(inp, gate)
print(input)
print(inp)
print(gate)
output_einsum = moe_einsum(input, gate)
output_einsum = moe_einsum(inp.clone(), gate.clone())
print(output)
print(output_einsum)
......
......@@ -161,9 +161,12 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));
int* gate_host = new int[batch_size];
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) {
aptrs.push_back(input + in_feat * i);
bptrs.push_back(weight + out_feat * in_feat * i);
bptrs.push_back(weight + out_feat * in_feat * gate_host[i]);
cptrs.push_back(output + out_feat * i);
}
checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
......@@ -177,9 +180,12 @@ void moe_cuda_forward_impl(
const scalar_t **B = (const scalar_t **)malloc(batch_size * sizeof(const scalar_t*));
checkCudaErrors(cudaMemcpy(B, Barray, batch_size * sizeof(const scalar_t*), cudaMemcpyDeviceToHost));
std::cout << weight << std::endl;
std::cout << input << " " << weight << " " << output << std::endl;
for (size_t i=0; i<batch_size; ++i) {
std::cout << B[i] << " " << bptrs[i] << std::endl;
std::cout << i << std::endl;
std::cout << "A " << aptrs[i] << std::endl;
std::cout << "B " << B[i] << " " << bptrs[i] << std::endl;
std::cout << "C " << cptrs[i] << std::endl;
}
scalar_t alpha = 1, beta = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment