Commit c5f73a0f authored by TiagoMAntunes

Bias moved to CUDA in forward. Basic bias setup for backwards (no kernel yet)

parent 6cdb3cda
@@ -32,55 +32,44 @@ std::vector<torch::Tensor> moe_local_gather(
     return moe_cuda_local_gather(output_buf, pos);
 }
 
-void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight, at::optional<torch::Tensor> bias_o) {
-    torch::Tensor bias = bias_o.value();
-    weight = at::cat({weight, bias.unsqueeze(2)}, 2); // [W b]
-    auto options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
-    input_buf = at::cat({input_buf, ones}, 1); // [X 1]
-}
-
 std::vector<torch::Tensor> moe_forward(
         torch::Tensor input_buf,            // [batch_size x in_feat]
-        torch::Tensor expert_count,         // [batch_size]
+        torch::Tensor expert_count,         // [num_expert]
         torch::Tensor weight,               // [num_expert x out_feat x in_feat]
         at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
 ) {
-    // Wx+b = [W b] [x]
-    //              [1]
-    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    return moe_cuda_forward(input_buf, expert_count, weight);
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
 }
 
 std::vector<torch::Tensor> moe_backward(
         torch::Tensor grad_output_buf,      // [batch_size x out_feat]
         torch::Tensor input_buf,            // [batch_size x in_feat]
-        torch::Tensor expert_count,
+        torch::Tensor expert_count,         // [num_expert]
         torch::Tensor weight,               // [num_expert x out_feat x in_feat]
         at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
 ) {
-    // Wx+b = [W b] [x]
-    //              [1]
-    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o.has_value());
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o);
 }
 
 #ifdef MOE_USE_NCCL
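The removed merge_bias helper handled the bias on the host by augmenting the operands, using the identity Wx + b = [W b][x; 1] and building the augmented tensors with at::cat before every call. After this change the wrappers only validate the optional bias and pass it through to moe_cuda_forward / moe_cuda_backward. A minimal single-expert libtorch sketch of that augmentation identity (tensor names are illustrative, not from the repository):

#include <torch/torch.h>

// Illustrative only: why the removed merge_bias trick computed the same affine map.
// W: [out_feat x in_feat], b: [out_feat], x: [batch x in_feat]
torch::Tensor affine_via_augmentation(torch::Tensor x, torch::Tensor W, torch::Tensor b) {
    auto W_aug = torch::cat({W, b.unsqueeze(1)}, 1);   // [W b] : [out_feat x (in_feat+1)]
    auto ones  = torch::ones({x.size(0), 1}, x.options());
    auto x_aug = torch::cat({x, ones}, 1);             // [x 1] : [batch x (in_feat+1)]
    return x_aug.matmul(W_aug.t());                    // equals x.matmul(W.t()) + b
}

The augmentation costs extra concatenations (and copies of input and weight) on every call, which is presumably what moving the bias into the CUDA path avoids.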
@@ -118,11 +118,12 @@ void moe_cuda_forward_impl(
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* output_buf,
+        const bool has_bias,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
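Setting beta to 1 relies on the standard GEMM contract C = alpha * op(A) * op(B) + beta * C: when the output buffer already holds the bias rows (see the repeat_interleave call in moe_cuda_forward further down), the per-expert GEMM accumulates onto them instead of overwriting them. A torch-level sketch of that equivalence, assuming a single expert and illustrative names:

#include <torch/torch.h>
#include <iostream>

// Sketch only: checks the beta-trick arithmetic at the torch level for one expert.
int main() {
    auto x = torch::randn({5, 8});     // [n x in_feat]
    auto W = torch::randn({16, 8});    // [out_feat x in_feat]
    auto b = torch::randn({16});       // [out_feat]

    // "beta = 1" style: start from the bias rows and accumulate x * W^T on top.
    auto fused   = torch::addmm(b.unsqueeze(0), x, W.t());
    // Conventional affine: GEMM first, bias added afterwards.
    auto classic = x.matmul(W.t()) + b;

    std::cout << torch::allclose(fused, classic) << std::endl;  // prints 1
}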
@@ -154,6 +155,8 @@ void moe_cuda_backward_impl(
         const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
+        scalar_t* grad_bias,
+        const bool has_bias,
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
@@ -195,6 +198,10 @@ void moe_cuda_backward_impl(
             grad_weight + i * in_feat * out_feat, in_feat
         ));
+
+        if (has_bias) {
+            // call bias kernel here
+        }
 
         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
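Per the commit message there is no bias kernel yet; the has_bias branch only marks where one would be launched. Since the bias gradient for expert i is the column-wise sum of grad_output over that expert's rows (expert_count[i] rows starting at ptr), a naive kernel could look like the sketch below. The kernel name, launch shape, and device-side ptr/count arrays are assumptions for illustration, not code from the repository:

// Hypothetical sketch of a bias-gradient kernel (not part of this commit).
// grad_output: [batch_size x out_feat], rows grouped by expert;
// ptr[e] is the first row belonging to expert e, count[e] its number of rows.
template <typename scalar_t>
__global__ void bias_backward_naive(
        const scalar_t* __restrict__ grad_output,
        const long* __restrict__ ptr,
        const long* __restrict__ count,
        scalar_t* __restrict__ grad_bias,
        size_t out_feat) {
    const int e = blockIdx.x;                        // one block per expert
    for (size_t j = threadIdx.x; j < out_feat; j += blockDim.x) {
        scalar_t acc = 0;
        for (long r = 0; r < count[e]; ++r) {
            acc += grad_output[(ptr[e] + r) * out_feat + j];
        }
        grad_bias[e * out_feat + j] = acc;           // dL/db[e][j] = sum over the expert's rows
    }
}

Launching it as bias_backward_naive<<<num_expert, 256, 0, stream>>>(...) would mirror the per-expert loop above; a production kernel would also parallelize the reduction over rows.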
@@ -276,7 +283,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
         torch::Tensor expert_count,
-        torch::Tensor weight
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -288,10 +296,17 @@ std::vector<torch::Tensor> moe_cuda_forward(
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             num_expert, in_feat, out_feat);
 #endif
+
+    torch::Tensor output;
+    if (bias.has_value()) {
+        output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
+    } else {
     auto out_options = torch::TensorOptions()
         .device(input_buf.device())
         .dtype(input_buf.dtype());
-    auto output = torch::empty({batch_size, out_feat}, out_options);
+    output = torch::empty({batch_size, out_feat}, out_options);
+    }
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
         ([&] {
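The repeat_interleave call is what pre-fills the output with bias rows: expert i's bias row is repeated expert_count[i] times along dim 0, so the result is [batch_size x out_feat] with rows in the same expert-sorted order as input_buf, ready for the beta = 1 GEMM above. A toy illustration of the semantics (sizes invented for the example):

#include <torch/torch.h>
#include <iostream>

int main() {
    // Toy example: 3 experts, out_feat = 2, expert_count = {2, 0, 3}.
    auto bias         = torch::arange(1.0f, 7.0f).reshape({3, 2});  // [[1,2],[3,4],[5,6]]
    auto expert_count = torch::tensor({2, 0, 3}, torch::kLong);

    // Row i of bias is repeated expert_count[i] times along dim 0, giving one
    // bias row per token, already grouped by expert:
    // [[1,2],[1,2],[5,6],[5,6],[5,6]]  -> shape [5 x 2] = [batch_size x out_feat]
    auto output = bias.repeat_interleave(expert_count, 0);
    std::cout << output << std::endl;
}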
@@ -300,6 +315,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             weight.data_ptr<scalar_t>(),
             expert_count.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
+            bias.has_value(),
             in_feat,
             out_feat,
             num_expert,
@@ -315,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor input_buf,        // [batch_size x out_feat]
         torch::Tensor expert_count,
         torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-        bool has_bias
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -331,6 +347,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_backward_impl<scalar_t>(
@@ -340,6 +357,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
             expert_count.data_ptr<long>(),
             grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
+            grad_bias.data_ptr<scalar_t>(),
+            bias.has_value(),
             batch_size,
             in_feat,
             out_feat,
@@ -348,17 +367,5 @@ std::vector<torch::Tensor> moe_cuda_backward(
         );
     }));
 
-    if (!has_bias) return {grad_input_buf, grad_weight, torch::empty({num_expert, out_feat})};
-
-    // weight and input have been concatenated. need to split the grads back
-    // and separate them into input, weight, bias
-    torch::Tensor grad_orig_input_buf = at::narrow(grad_input_buf, -1, 0, in_feat - 1).contiguous();
-    // bias is also squeezed in the new added dimension
-    torch::Tensor grad_orig_bias = at::narrow(grad_weight, -1, in_feat - 1, 1).squeeze(2).contiguous();
-    torch::Tensor grad_orig_weight = at::narrow(grad_weight, -1, 0, in_feat - 1).contiguous();
-
-    return {grad_orig_input_buf, grad_orig_weight, grad_orig_bias};
+    return {grad_input_buf, grad_weight, grad_bias};
 }
@@ -20,14 +20,15 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
         torch::Tensor expert_count,
-        torch::Tensor weight);
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias);
 
 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output_buf,
         torch::Tensor input_buf,
         torch::Tensor expert_count,
         torch::Tensor weight,
-        bool has_bias);
+        at::optional<torch::Tensor> bias);
 
 #ifdef MOE_USE_NCCL