"src/vscode:/vscode.git/clone" did not exist on "61dec53356949ce6cd4e5bfbf64abf6ee7b5d5d9"
Commit 303d0e93 authored by TiagoMAntunes

Bias is now calculated directly in the MOELinear layer. Added the corresponding CUDA changes and updated the forward and backward functions of MOELinear.
parent 8bac18dc
@@ -33,36 +33,54 @@ std::vector<torch::Tensor> moe_local_gather(
 }
 
+void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight,
+                at::optional<torch::Tensor> bias_o) {
+    torch::Tensor bias = bias_o.value();
+    weight = at::cat({weight, bias.unsqueeze(2)}, 2);  // [W b]
+    auto options = torch::TensorOptions()
+        .device(input_buf.device())
+        .dtype(input_buf.dtype());
+    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
+    input_buf = at::cat({input_buf, ones}, 1);  // [X 1]
+}
+
 std::vector<torch::Tensor> moe_forward(
     torch::Tensor input_buf,            // [batch_size x in_feat]
-    torch::Tensor weight,               // [num_expert x out_feat x in_feat]
-    torch::Tensor expert_count          // [batch_size]
+    torch::Tensor expert_count,         // [batch_size]
+    torch::Tensor weight,               // [num_expert x out_feat x in_feat]
+    at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
 ) {
+    // Wx+b = [W b] [x]
+    //              [1]
+    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-    The bias term should have been merged into weight. Note the following fact that
-    Wx+b = [W b] [x]
-                 [1]
-    */
-    return moe_cuda_forward(input_buf, weight, expert_count);
+    return moe_cuda_forward(input_buf, expert_count, weight);
 }
 
 std::vector<torch::Tensor> moe_backward(
     torch::Tensor grad_output_buf,      // [batch_size x out_feat]
-    torch::Tensor input_buf,            // [batch_size x out_feat]
-    torch::Tensor weight,               // [num_expert x out_feat x in_feat]
-    torch::Tensor expert_count
+    torch::Tensor input_buf,            // [batch_size x in_feat]
+    torch::Tensor expert_count,
+    torch::Tensor weight,               // [num_expert x out_feat x in_feat]
+    at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
 ) {
+    // Wx+b = [W b] [x]
+    //              [1]
+    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
-    /*
-    The bias term should have been merged into weight. Note the following fact that
-    Wx+b = [W b] [x]
-                 [1]
-    */
-    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight,
+                             bias_o.has_value());
 }
 
 #ifdef MOE_USE_NCCL
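The new merge_bias helper relies on the homogeneous-coordinates identity quoted in the comments, Wx + b = [W b][x; 1]: the bias becomes an extra column of each expert's weight matrix while the input gains a matching column of ones, so a single matrix multiply computes the affine layer. A minimal PyTorch sketch (illustration only, not part of the commit) that checks the identity for a single expert's slice:

import torch

# One expert's slice: W is [out_feat x in_feat], b is [out_feat].
torch.manual_seed(0)
in_feat, out_feat, batch = 4, 3, 5
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)
x = torch.randn(batch, in_feat)

# Reference: ordinary affine layer, Wx + b.
y_ref = x @ W.t() + b

# Merged form: b as an extra column of W, ones as an extra column of x.
W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)         # [out_feat x (in_feat+1)]
x_aug = torch.cat([x, torch.ones(batch, 1)], dim=1)   # [batch x (in_feat+1)]
y_merged = x_aug @ W_aug.t()

assert torch.allclose(y_ref, y_merged, atol=1e-6)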
@@ -275,8 +275,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input_buf,
-    torch::Tensor weight,
-    torch::Tensor expert_count
+    torch::Tensor expert_count,
+    torch::Tensor weight
 ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -311,10 +311,11 @@ std::vector<torch::Tensor> moe_cuda_forward(
 }
 
 std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output_buf,  // [batch_size x out_feat]
     torch::Tensor input_buf,        // [batch_size x out_feat]
-    torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-    torch::Tensor expert_count
+    torch::Tensor expert_count,
+    torch::Tensor weight,           // [num_expert x out_feat x in_feat]
+    bool has_bias
 ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -347,5 +348,17 @@ std::vector<torch::Tensor> moe_cuda_backward(
         );
     }));
 
-    return {grad_input_buf, grad_weight};
+    if (!has_bias)
+        return {grad_input_buf, grad_weight, torch::empty({num_expert, out_feat})};
+
+    // weight and input have been concatenated; the grads must be split back
+    // into input, weight and bias
+    torch::Tensor grad_orig_input_buf =
+        at::narrow(grad_input_buf, -1, 0, in_feat - 1).contiguous();
+    // the bias grad is also squeezed out of the newly added dimension
+    torch::Tensor grad_orig_bias =
+        at::narrow(grad_weight, -1, in_feat - 1, 1).squeeze(2).contiguous();
+    torch::Tensor grad_orig_weight =
+        at::narrow(grad_weight, -1, 0, in_feat - 1).contiguous();
+    return {grad_orig_input_buf, grad_orig_weight, grad_orig_bias};
 }
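Because merge_bias widened the input by one column, the grad_input_buf and grad_weight coming out of the kernel refer to the augmented tensors, and the final lines of moe_cuda_backward carve them back apart with at::narrow. The same splitting step rewritten in PyTorch for illustration (names such as in_feat_plus_1 are ours, not the kernel's; here in_feat_plus_1 plays the role the code calls in_feat after the merge):

import torch

num_expert, out_feat, in_feat_plus_1, batch = 2, 3, 5, 4
grad_input_buf = torch.randn(batch, in_feat_plus_1)
grad_weight = torch.randn(num_expert, out_feat, in_feat_plus_1)

# Drop the gradient of the appended ones column from the input grad.
grad_orig_input = grad_input_buf.narrow(-1, 0, in_feat_plus_1 - 1).contiguous()
# The last column of each expert's weight grad is the bias grad.
grad_orig_bias = grad_weight.narrow(-1, in_feat_plus_1 - 1, 1).squeeze(2).contiguous()
grad_orig_weight = grad_weight.narrow(-1, 0, in_feat_plus_1 - 1).contiguous()

assert grad_orig_input.shape == (batch, in_feat_plus_1 - 1)
assert grad_orig_weight.shape == (num_expert, out_feat, in_feat_plus_1 - 1)
assert grad_orig_bias.shape == (num_expert, out_feat)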
@@ -19,14 +19,15 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input_buf,
-    torch::Tensor weight,
-    torch::Tensor expert_count);
+    torch::Tensor expert_count,
+    torch::Tensor weight);
 
 std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output_buf,
     torch::Tensor input_buf,
+    torch::Tensor expert_count,
     torch::Tensor weight,
-    torch::Tensor expert_count);
+    bool has_bias);
 
 #ifdef MOE_USE_NCCL
@@ -112,19 +112,23 @@ class MOELinear(Function):
     @staticmethod
     def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
         (global_output_buf,) = fmoe_cuda.forward(
-            global_input_buf, weight, fwd_expert_count
+            global_input_buf, fwd_expert_count, weight, bias
         )
-        variables = (global_input_buf, fwd_expert_count, weight)
+        variables = (global_input_buf, fwd_expert_count, weight, bias)
         ctx.save_for_backward(*variables)
         return global_output_buf
 
     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, fwd_expert_count, weight) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(
-            grad_out, input_buf, weight, fwd_expert_count
+        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.backward(
+            grad_out, input_buf, fwd_expert_count, weight, bias
         )
-        return grad_inp_buf, None, grad_weight
+
+        if not torch.is_tensor(bias):
+            grad_bias = None
+
+        return grad_inp_buf, None, grad_weight, grad_bias
 
 class MOEGather(Function):
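For reference, the contract of MOELinear.apply can be stated in plain PyTorch: fwd_expert_count[i] consecutive rows of the input are multiplied by expert i's weight, plus its bias when one is given. The loop below is a slow reference sketch of that contract as we read it from the diff, not code from the library:

import torch

def moe_linear_reference(inp, fwd_expert_count, weight, bias=None):
    # inp: [batch x in_feat], weight: [num_expert x out_feat x in_feat],
    # bias: [num_expert x out_feat] or None. Assumes at least one count > 0.
    outputs, base = [], 0
    for i, count in enumerate(fwd_expert_count.tolist()):
        if count == 0:
            continue
        chunk = inp[base:base + count]   # rows routed to expert i
        out = chunk @ weight[i].t()      # [count x out_feat]
        if bias is not None:
            out = out + bias[i]
        outputs.append(out)
        base += count
    return torch.cat(outputs, dim=0)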
@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
         r"""
         Call MOE function
         """
-        x = MOELinear.apply(inp, fwd_expert_count, self.weight)
-        if self.bias is not None:
-            # TODO: torch.repeat_interleave seems to have numerical
-            # instability in backward, leading to incorrect
-            # gradient computation for Solutions 1 and 2.
-            # Solution 3 uses a for-loop to expand the bias,
-            # but is 50% slower.
-            # This part should finally go into MOELinear.apply,
-            # like MOELinear.apply(x, weight, bias, count)
-
-            # Solution 1
-            bias = torch.repeat_interleave(
-                self.bias, fwd_expert_count.to(self.bias.device), dim=0
-            )
-
-            # Solution 2
-            # bias_idx = torch.arange(self.num_expert)\
-            #     .repeat_interleave(fwd_expert_count)
-            # bias = self.bias[bias_idx]
-
-            # Solution 3
-            # bias = []
-            # for i in range(self.num_expert):
-            #     if fwd_expert_count[i] > 0:
-            #         bias.append(
-            #             self.bias[i].unsqueeze(0).expand(
-            #                 fwd_expert_count[i], -1
-            #             )
-            #         )
-            # bias = torch.cat(bias, dim=0)
-
-            x = x + bias
+        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
         return x
 
     def extra_repr(self) -> str:
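The deleted workaround expanded the per-expert bias to one row per token before adding it, and the TODO notes that Solutions 1 and 2 should agree. For completeness, a small sketch (illustration only, not part of the commit) showing that the two removed expansions do compute the same tensor:

import torch

num_expert, out_feat = 3, 4
bias = torch.randn(num_expert, out_feat)
fwd_expert_count = torch.tensor([2, 0, 3])

# Solution 1: repeat each expert's bias row count[i] times.
bias1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

# Solution 2: build per-token expert indices, then gather the rows.
bias_idx = torch.arange(num_expert).repeat_interleave(fwd_expert_count)
bias2 = bias[bias_idx]

assert torch.equal(bias1, bias2)  # both are [sum(count) x out_feat]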