Commit 303d0e93 authored by TiagoMAntunes

Bias now being calculated directly in MOELinear layer. Added corresponding CUDA changes. Updated forward and backward functions of MOELinear
parent 8bac18dc
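The whole change rests on the affine folding identity quoted in the comments below: Wx + b = [W b][x; 1], i.e. a bias can be absorbed into the weight matrix by appending it as an extra column and appending a constant 1 to the input. A minimal PyTorch sketch of the identity (illustration only; every name is local to this example, nothing below is repo code):

import torch

out_feat, in_feat = 4, 3
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)
x = torch.randn(in_feat)

# Append b as an extra column of W and a trailing 1 to x;
# the matmul then applies weight and bias in one shot.
W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)  # [out_feat, in_feat + 1]
x_aug = torch.cat([x, torch.ones(1)])          # [in_feat + 1]

assert torch.allclose(W @ x + b, W_aug @ x_aug, atol=1e-6)

Folding the bias this way lets the batched per-expert GEMM handle weight and bias together, at the cost of one extra input column.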
@@ -33,36 +33,54 @@ std::vector<torch::Tensor> moe_local_gather(
 }
 
+void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight, at::optional<torch::Tensor> bias_o) {
+    torch::Tensor bias = bias_o.value();
+    weight = at::cat({weight, bias.unsqueeze(2)}, 2); // [W b]
+    auto options = torch::TensorOptions()
+                       .device(input_buf.device())
+                       .dtype(input_buf.dtype());
+    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
+    input_buf = at::cat({input_buf, ones}, 1); // [x 1]
+}
+
 std::vector<torch::Tensor> moe_forward(
     torch::Tensor input_buf,           // [batch_size x in_feat]
+    torch::Tensor expert_count,        // [batch_size]
     torch::Tensor weight,              // [num_expert x out_feat x in_feat]
-    torch::Tensor expert_count         // [batch_size]
+    at::optional<torch::Tensor> bias_o // [num_expert x out_feat] or None
 ) {
+    // Wx+b = [W b] [x]
+    //              [1]
+    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-        The bias term should have been merged into weight. Note the following fact that
-        Wx+b = [W b] [x]
-                     [1]
-    */
-    return moe_cuda_forward(input_buf, weight, expert_count);
+    return moe_cuda_forward(input_buf, expert_count, weight);
 }
 
 std::vector<torch::Tensor> moe_backward(
     torch::Tensor grad_output_buf,     // [batch_size x out_feat]
-    torch::Tensor input_buf,           // [batch_size x out_feat]
+    torch::Tensor input_buf,           // [batch_size x in_feat]
+    torch::Tensor expert_count,
     torch::Tensor weight,              // [num_expert x out_feat x in_feat]
-    torch::Tensor expert_count
+    at::optional<torch::Tensor> bias_o // [num_expert x out_feat] or None
 ) {
+    // Wx+b = [W b] [x]
+    //              [1]
+    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-        The bias term should have been merged into weight. Note the following fact that
-        Wx+b = [W b] [x]
-                     [1]
-    */
-    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o.has_value());
 }
 
 #ifdef MOE_USE_NCCL
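In PyTorch terms, the merge_bias helper above performs exactly this augmentation, but per expert. A hedged sketch of the equivalent shape bookkeeping (random stand-in tensors, not the repo's buffers):

import torch

num_expert, out_feat, in_feat, batch = 2, 4, 3, 5
weight = torch.randn(num_expert, out_feat, in_feat)
bias = torch.randn(num_expert, out_feat)
input_buf = torch.randn(batch, in_feat)

# [W b]: each expert's bias becomes one extra input column of its weight.
weight = torch.cat([weight, bias.unsqueeze(2)], dim=2)
# -> [num_expert, out_feat, in_feat + 1]

# [x 1]: a ones column makes every row pick up its expert's bias column.
ones = torch.ones(batch, 1, dtype=input_buf.dtype, device=input_buf.device)
input_buf = torch.cat([input_buf, ones], dim=1)
# -> [batch, in_feat + 1]

After the merge, the in_feat that moe_cuda_forward and moe_cuda_backward observe is the augmented size, which is why the backward path below narrows with in_feat - 1.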
@@ -275,8 +275,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input_buf,
-    torch::Tensor weight,
-    torch::Tensor expert_count
+    torch::Tensor expert_count,
+    torch::Tensor weight
 ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -313,8 +313,9 @@ std::vector<torch::Tensor> moe_cuda_forward(
 std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output_buf, // [batch_size x out_feat]
     torch::Tensor input_buf,       // [batch_size x out_feat]
+    torch::Tensor expert_count,
     torch::Tensor weight,          // [num_expert x out_feat x in_feat]
-    torch::Tensor expert_count
+    bool has_bias
 ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -347,5 +348,17 @@ std::vector<torch::Tensor> moe_cuda_backward(
         );
     }));
 
-    return {grad_input_buf, grad_weight};
+    if (!has_bias) return {grad_input_buf, grad_weight, torch::empty({num_expert, out_feat})};
+
+    // weight and input were concatenated with the bias and ones columns,
+    // so the gradients must be split back into input, weight and bias parts.
+    // Note: in_feat here is the augmented size (original in_feat + 1).
+    torch::Tensor grad_orig_input_buf = at::narrow(grad_input_buf, -1, 0, in_feat - 1).contiguous();
+    // the bias gradient is the appended last column of grad_weight; squeeze out the extra dim
+    torch::Tensor grad_orig_bias = at::narrow(grad_weight, -1, in_feat - 1, 1).squeeze(2).contiguous();
+    torch::Tensor grad_orig_weight = at::narrow(grad_weight, -1, 0, in_feat - 1).contiguous();
+
+    return {grad_orig_input_buf, grad_orig_weight, grad_orig_bias};
 }
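The narrow/squeeze split above can be sanity-checked against autograd with a single expert. A sketch under the same augmentation (all names local to the example):

import torch

out_feat, in_feat, batch = 4, 3, 5
W = torch.randn(out_feat, in_feat, requires_grad=True)
b = torch.randn(out_feat, requires_grad=True)
x = torch.randn(batch, in_feat, requires_grad=True)

# Reference gradients from an ordinary biased linear.
(x @ W.t() + b).sum().backward()
ref = (x.grad.clone(), W.grad.clone(), b.grad.clone())

# Merged path: autograd only ever sees the augmented tensors.
W_aug = torch.cat([W.detach(), b.detach().unsqueeze(1)], dim=1).requires_grad_()
x_aug = torch.cat([x.detach(), torch.ones(batch, 1)], dim=1).requires_grad_()
(x_aug @ W_aug.t()).sum().backward()

# Split as the C++ above does: drop the ones-column gradient from the
# input, peel off the last weight column as the bias gradient.
grad_x = x_aug.grad[:, :in_feat]
grad_W = W_aug.grad[:, :in_feat]
grad_b = W_aug.grad[:, in_feat]

for r, g in zip(ref, (grad_x, grad_W, grad_b)):
    assert torch.allclose(r, g, atol=1e-6)

The gradient with respect to the appended ones column is simply discarded, since the ones are constants. When has_bias is false, the kernel returns an uninitialized placeholder for the bias gradient, which the Python wrapper below replaces with None.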
@@ -19,14 +19,15 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input_buf,
-    torch::Tensor weight,
-    torch::Tensor expert_count);
+    torch::Tensor expert_count,
+    torch::Tensor weight);
 
 std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output_buf,
     torch::Tensor input_buf,
+    torch::Tensor expert_count,
     torch::Tensor weight,
-    torch::Tensor expert_count);
+    bool has_bias);
 
 #ifdef MOE_USE_NCCL
@@ -112,19 +112,23 @@ class MOELinear(Function):
     @staticmethod
-    def forward(ctx, global_input_buf, fwd_expert_count, weight):
+    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
         (global_output_buf,) = fmoe_cuda.forward(
-            global_input_buf, weight, fwd_expert_count
+            global_input_buf, fwd_expert_count, weight, bias
         )
-        variables = (global_input_buf, fwd_expert_count, weight)
+        variables = (global_input_buf, fwd_expert_count, weight, bias)
         ctx.save_for_backward(*variables)
         return global_output_buf
 
     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, fwd_expert_count, weight) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(
-            grad_out, input_buf, weight, fwd_expert_count
+        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.backward(
+            grad_out, input_buf, fwd_expert_count, weight, bias
         )
-        return grad_inp_buf, None, grad_weight
+
+        if not torch.is_tensor(bias):
+            grad_bias = None
+
+        return grad_inp_buf, None, grad_weight, grad_bias
 
 class MOEGather(Function):
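To make the contract of the updated MOELinear concrete, here is a hedged, loop-based reference (moe_linear_ref is hypothetical, not part of the repo), assuming, as the kernels do, that the rows of inp arrive grouped by expert:

import torch

def moe_linear_ref(inp, fwd_expert_count, weight, bias=None):
    # inp:    [batch, in_feat], rows sorted so that the first
    #         fwd_expert_count[0] rows belong to expert 0, and so on
    # weight: [num_expert, out_feat, in_feat]
    # bias:   [num_expert, out_feat] or None
    out = inp.new_empty(inp.size(0), weight.size(1))
    base = 0
    for e, cnt in enumerate(fwd_expert_count.tolist()):
        if cnt == 0:
            continue
        rows = slice(base, base + cnt)
        out[rows] = inp[rows] @ weight[e].t()
        if bias is not None:
            out[rows] += bias[e]
        base += cnt
    return out

A reference like this is handy for testing the fused path: compare outputs in forward, and compare the returned grad_bias via autograd in backward.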
@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
         r"""
         Call MOE function
         """
-        x = MOELinear.apply(inp, fwd_expert_count, self.weight)
-        if self.bias is not None:
-            # TODO: torch.repeat_interleave seems to have numerical
-            # instability in backward, leading to incorrect
-            # gradient computation for solutions 1 and 2.
-            # Solution 3 uses a for-loop to expand the bias,
-            # but is 50% slower.
-            # This part should finally go into MOELinear.apply,
-            # like MOELinear.apply(x, weight, bias, count)
-            # Solution 1
-            bias = torch.repeat_interleave(
-                self.bias, fwd_expert_count.to(self.bias.device), dim=0
-            )
-            # Solution 2
-            # bias_idx = torch.arange(self.num_expert)\
-            #     .repeat_interleave(fwd_expert_count)
-            # bias = self.bias[bias_idx]
-            # Solution 3
-            # bias = []
-            # for i in range(self.num_expert):
-            #     if fwd_expert_count[i] > 0:
-            #         bias.append(
-            #             self.bias[i].unsqueeze(0).expand(
-            #                 fwd_expert_count[i], -1
-            #             )
-            #         )
-            # bias = torch.cat(bias, dim=0)
-            x = x + bias
+        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
         return x
 
     def extra_repr(self) -> str:
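For context, the workaround removed above (Solution 1) expanded the per-expert bias on the Python side before adding it. A self-contained sketch of its semantics:

import torch

num_expert, out_feat = 3, 4
bias = torch.randn(num_expert, out_feat)
fwd_expert_count = torch.tensor([2, 0, 3])  # tokens routed to each expert

# One copy of expert e's bias per token routed to e, in expert order.
expanded = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

assert expanded.shape == (int(fwd_expert_count.sum()), out_feat)
assert torch.equal(expanded[0], bias[0]) and torch.equal(expanded[2], bias[2])

Computing the bias inside the fused GEMM sidesteps both the backward instability the removed TODO reports for repeat_interleave and the 50% slowdown of the loop-based variant.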