Commit 284f1424 authored by Rick Ho
Browse files

degrade to single fc fwd

parent d690c7b2
......@@ -7,8 +7,7 @@
// Forward declaration of the CUDA MoE forward launcher (single FC per expert).
// input:  [batch_size x in_feat] activations
// gate:   [batch_size] per-sample expert index (int tensor)
// weight: [num_expert x out_feat x in_feat] per-expert FC weight
//         (bias is assumed to have been merged into weight)
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input,
        torch::Tensor gate,
        torch::Tensor weight);
std::vector<torch::Tensor> moe_cuda_backward(
torch::Tensor grad_output,
......@@ -26,19 +25,17 @@ std::vector<torch::Tensor> moe_cuda_backward(
// Python-facing entry point: validates that all tensors are contiguous CUDA
// tensors, then dispatches to the CUDA implementation.
std::vector<torch::Tensor> moe_forward(
        torch::Tensor input,  // [batch_size x in_feat]
        torch::Tensor gate,   // [batch_size], expert index per sample
        torch::Tensor weight  // [num_expert x out_feat x in_feat]
        ) {
    CHECK_INPUT(input);
    CHECK_INPUT(gate);
    CHECK_INPUT(weight);
    /*
       The bias term should have been merged into weight. Note the following
       fact that
           Wx + b = [W b] [x]
                          [1]
    */
    return moe_cuda_forward(input, gate, weight);
}
std::vector<torch::Tensor> moe_backward(
......
......@@ -10,11 +10,11 @@ torch.cuda.manual_seed(42)
class MOEFunction(Function):
@staticmethod
def forward(ctx, inp, gate, weight1, weight2):
def forward(ctx, inp, gate, weight):
# out_feat, in_feat = weight.size()[1:]
# weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
output = moe_cuda.forward(inp, gate, weight1, weight2)
variables = [inp, gate, weight1, weight2]
output = moe_cuda.forward(inp, gate, weight)
variables = [inp, gate, weight]
ctx.save_for_backward(*variables)
return output[0]
......@@ -32,59 +32,46 @@ class MOEFunction(Function):
class MOELayer(nn.Module):
    """Mixture-of-experts layer with a single fully-connected layer per
    expert, dispatched to the fused CUDA kernel via MOEFunction.

    weight has shape [num_expert, out_feat, in_feat]; the bias is assumed
    to be merged into the weight by the caller.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # One [out_feat x in_feat] weight matrix per expert.
        self.weight = nn.Parameter(
                torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Borrow nn.Linear's default initialization for each expert's weight.
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        return MOEFunction.apply(inp, gate, self.weight)
class MOELayer_raw(nn.Module):
    """Pure-PyTorch reference implementation of the single-FC MoE layer.

    Computes, per sample i:  out[i] = inp[i] @ weight[gate[i]].T
    Used to validate the fused CUDA implementation in MOELayer.
    """

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(MOELayer_raw, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # One [out_feat x in_feat] weight matrix per expert.
        self.weight = nn.Parameter(
                torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Borrow nn.Linear's default initialization for each expert's weight.
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Route each sample through the expert selected by its gate index.
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x
......@@ -105,15 +92,13 @@ def test():
batch_size = 4
num_expert = 2
in_feat = 6
hidden_feat = 12
out_feat = 7
linear = nn.Linear(in_feat, in_feat).cuda()
moe = MOELayer(num_expert, in_feat, hidden_feat, out_feat).cuda()
moe_raw = MOELayer_raw(num_expert, in_feat, hidden_feat, out_feat).cuda()
moe_raw.weight1.data = moe.weight1.data.clone()
moe_raw.weight2.data = moe.weight2.data.clone()
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
moe_raw.weight.data = moe.weight.data.clone()
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
......
......@@ -58,12 +58,10 @@ template <typename scalar_t>
void moe_cuda_forward_impl(
const scalar_t* input,
const int* d_gate,
const scalar_t* weight1,
const scalar_t* weight2,
const scalar_t* weight,
scalar_t* output,
const size_t batch_size,
const size_t in_feat,
const size_t hidden_feat,
const size_t out_feat,
const size_t num_expert) {
......@@ -73,14 +71,12 @@ void moe_cuda_forward_impl(
timestamp(t_init);
#endif
scalar_t *input_buf, *hidden_buf, *output_buf;
scalar_t *input_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
in_feat));
checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
out_feat));
checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
hidden_feat));
#ifdef MOE_BREAKDOWN
timestamp(t_malloc);
......@@ -152,22 +148,11 @@ void moe_cuda_forward_impl(
checkCudaErrors(cublasXgemm(h->getHandle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
hidden_feat, expert_count[i], in_feat,
out_feat, expert_count[i], in_feat,
&alpha,
weight1 + i * in_feat * hidden_feat, in_feat,
weight + i * in_feat * out_feat, in_feat,
input_buf + ptr * in_feat, in_feat,
&beta,
hidden_buf + hidden_feat * ptr, hidden_feat
));
checkCudaErrors(cublasXgemm(h->getHandle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], hidden_feat,
&alpha,
weight2 + i * hidden_feat * out_feat, hidden_feat,
hidden_buf + hidden_feat * ptr, hidden_feat,
&beta,
output_buf + out_feat * ptr, out_feat
));
......@@ -195,7 +180,6 @@ void moe_cuda_forward_impl(
#endif
cudaFree(input_buf);
cudaFree(hidden_buf);
cudaFree(output_buf);
cudaFree(d_pos);
delete [] pos;
......@@ -244,17 +228,15 @@ void moe_cuda_grad_weight(
std::vector<torch::Tensor> moe_cuda_forward(
torch::Tensor input,
torch::Tensor gate,
torch::Tensor weight1,
torch::Tensor weight2
torch::Tensor weight
) {
const auto batch_size = input.size(0);
const auto num_expert = weight1.size(0);
const auto out_feat = weight2.size(1);
const auto hidden_feat = weight1.size(1);
const auto in_feat = weight1.size(2);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld,out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
auto output = input.new_zeros({batch_size, out_feat});
......@@ -262,12 +244,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
moe_cuda_forward_impl<scalar_t>(
input.data_ptr<scalar_t>(),
gate.data_ptr<int>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
batch_size,
in_feat,
hidden_feat,
out_feat,
num_expert
);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment