"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "544ba677dd97a49c8124208837025aa8b5ab639e"
Commit c5f73a0f authored by TiagoMAntunes

Bias moved to CUDA in forward. Basic bias setup for backwards (no kernel yet)

parent 6cdb3cda
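In outline: both versions compute the per-expert affine map y = x W^T + b. The removed host-side path (merge_bias below) folded the bias into the GEMM operands, while the new path hands the optional bias to the CUDA side, pre-fills the output buffer with each routed row's bias, and runs the GEMM with beta = 1 so the bias is accumulated in place; the backward pass now allocates grad_bias and threads has_bias through, with the reduction kernel still left as a TODO. Roughly:

    old:  y = [x  1] [W  b]^T                     (operands augmented on the host by merge_bias)
    new:  y := bias row, then y = x W^T + 1 * y   (the CUDA GEMM runs with beta = 1)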
@@ -32,55 +32,44 @@ std::vector<torch::Tensor> moe_local_gather(
     return moe_cuda_local_gather(output_buf, pos);
 }
 
-void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight, at::optional<torch::Tensor> bias_o) {
-    torch::Tensor bias = bias_o.value();
-    weight = at::cat({weight, bias.unsqueeze(2)}, 2);  // [W b]
-    auto options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
-    input_buf = at::cat({input_buf, ones}, 1);  // [X 1]
-}
-
 std::vector<torch::Tensor> moe_forward(
         torch::Tensor input_buf,            // [batch_size x in_feat]
-        torch::Tensor expert_count,         // [batch_size]
+        torch::Tensor expert_count,         // [num_expert]
         torch::Tensor weight,               // [num_expert x out_feat x in_feat]
         at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
         ) {
-    // Wx+b = [W b] [x]
-    //              [1]
-    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    return moe_cuda_forward(input_buf, expert_count, weight);
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
 }
 
 std::vector<torch::Tensor> moe_backward(
         torch::Tensor grad_output_buf,      // [batch_size x out_feat]
         torch::Tensor input_buf,            // [batch_size x in_feat]
-        torch::Tensor expert_count,
+        torch::Tensor expert_count,         // [num_expert]
        torch::Tensor weight,               // [num_expert x out_feat x in_feat]
         at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
         ) {
-    // Wx+b = [W b] [x]
-    //              [1]
-    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o.has_value());
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o);
 }
 
 #ifdef MOE_USE_NCCL
...
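For reference, the trick that the deleted merge_bias implemented can be spelled out for a single expert with plain ATen calls. This is an illustrative sketch, not code from the repository; the function name and the single-expert shapes are assumptions.

    // Minimal sketch: the augmented-operand trick behind the removed merge_bias.
    // Assumes x: [batch, in_feat], w: [out_feat, in_feat], b: [out_feat].
    #include <torch/torch.h>

    torch::Tensor linear_via_augmented_gemm(
            torch::Tensor x, torch::Tensor w, torch::Tensor b) {
        // [W b]: append the bias as one extra input column of the weight
        auto w_aug = torch::cat({w, b.unsqueeze(1)}, 1);        // [out_feat, in_feat + 1]
        // [x 1]: append a column of ones to the input
        auto ones  = torch::ones({x.size(0), 1}, x.options());
        auto x_aug = torch::cat({x, ones}, 1);                  // [batch, in_feat + 1]
        // x W^T + b == [x 1] [W b]^T
        return x_aug.matmul(w_aug.t());                         // [batch, out_feat]
    }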
@@ -118,11 +118,12 @@ void moe_cuda_forward_impl(
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* output_buf,
+        const bool has_bias,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
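On the CUDA side the forward driver only changes the GEMM's beta. The identity it relies on, C = alpha * A * B + beta * C with C pre-filled by the bias, can be mimicked on the host with torch::addmm; the sketch below is an assumption-level analogue, not the kernel code.

    // Sketch: why beta = has_bias ? 1 : 0 adds the bias for free when the
    // output buffer already holds the bias row-wise.
    #include <torch/torch.h>

    torch::Tensor expert_linear(torch::Tensor x, torch::Tensor w, torch::Tensor b) {
        // Pre-fill the output with the bias, one copy per input row.
        auto out = b.unsqueeze(0).repeat({x.size(0), 1});       // [batch, out_feat]
        // addmm(C, A, B, beta, alpha) returns beta * C + alpha * (A @ B),
        // mirroring what the GEMM does with output_buf as C.
        return torch::addmm(out, x, w.t(), /*beta=*/1, /*alpha=*/1);
    }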
@@ -154,6 +155,8 @@ void moe_cuda_backward_impl(
         const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
+        scalar_t* grad_bias,
+        const bool has_bias,
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
@@ -194,6 +197,10 @@ void moe_cuda_backward_impl(
             &beta,
             grad_weight + i * in_feat * out_feat, in_feat
         ));
+
+        if (has_bias) {
+            // call bias kernel here
+        }
+
         ptr += expert_count[i];
     }
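The has_bias branch above is deliberately left empty ("no kernel yet" in the commit message). For orientation only, one naive shape such a kernel could take is a per-column sum over this expert's grad_output rows; everything below (kernel name, launch shape) is an assumption, not part of the commit.

    // Illustrative sketch: grad_bias[i][j] = sum over this expert's rows r of grad_output[r][j].
    template <typename scalar_t>
    __global__ void grad_bias_naive_kernel(
            const scalar_t* grad_output,   // this expert's rows: [rows x out_feat]
            scalar_t* grad_bias,           // this expert's slice: [out_feat]
            size_t rows,
            size_t out_feat) {
        size_t j = blockIdx.x * blockDim.x + threadIdx.x;
        if (j >= out_feat) return;
        scalar_t acc = 0;
        for (size_t r = 0; r < rows; ++r) {
            acc += grad_output[r * out_feat + j];   // accumulate column j over the batch rows
        }
        grad_bias[j] = acc;
    }
    // A possible launch inside the has_bias branch (blocks/threads/stream are placeholders):
    //   grad_bias_naive_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
    //       grad_output + ptr * out_feat, grad_bias + i * out_feat, expert_count[i], out_feat);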
@@ -276,7 +283,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
         torch::Tensor expert_count,
-        torch::Tensor weight
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -288,11 +296,18 @@ std::vector<torch::Tensor> moe_cuda_forward(
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             num_expert, in_feat, out_feat);
 #endif
-    auto out_options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto output = torch::empty({batch_size, out_feat}, out_options);
+    torch::Tensor output;
+
+    if (bias.has_value()) {
+        output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
+    } else{
+        auto out_options = torch::TensorOptions()
+            .device(input_buf.device())
+            .dtype(input_buf.dtype());
+        output = torch::empty({batch_size, out_feat}, out_options);
+    }
+
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
         ([&] {
             moe_cuda_forward_impl<scalar_t>(
@@ -300,6 +315,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
                 weight.data_ptr<scalar_t>(),
                 expert_count.data_ptr<long>(),
                 output.data_ptr<scalar_t>(),
+                bias.has_value(),
                 in_feat,
                 out_feat,
                 num_expert,
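When a bias is supplied, the output buffer is built by repeating each expert's bias row expert_count[e] times, so row k already holds the bias of the expert it is routed to before the GEMM adds x W^T on top (beta = 1). A tiny standalone illustration with made-up shapes, not repository code:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        auto bias = torch::tensor({{1.0f, 1.0f}, {2.0f, 2.0f}});   // [num_expert = 2, out_feat = 2]
        auto expert_count = torch::tensor({3L, 1L});                // 3 rows routed to expert 0, 1 to expert 1
        // Rows come out as [[1, 1], [1, 1], [1, 1], [2, 2]]:
        // each row starts out holding the bias of the expert it belongs to.
        auto output = bias.repeat_interleave(expert_count, 0);      // [batch_size = 4, out_feat = 2]
        std::cout << output << "\n";
    }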
@@ -315,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor input_buf,        // [batch_size x out_feat]
         torch::Tensor expert_count,
         torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-        bool has_bias
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -331,6 +347,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_backward_impl<scalar_t>(
@@ -340,6 +357,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
             expert_count.data_ptr<long>(),
             grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
+            grad_bias.data_ptr<scalar_t>(),
+            bias.has_value(),
             batch_size,
             in_feat,
             out_feat,
@@ -348,17 +367,5 @@ std::vector<torch::Tensor> moe_cuda_backward(
         );
     }));
 
-    if (!has_bias) return {grad_input_buf, grad_weight, torch::empty({num_expert,out_feat})};
-
-    // weight and input have been concatenated. need to split the grads back
-    // and separate them into input, weight, bias
-    torch::Tensor grad_orig_input_buf = at::narrow(grad_input_buf, -1, 0, in_feat - 1).contiguous();
-
-    // bias is also squeezed in the new added dimension
-    torch::Tensor grad_orig_bias = at::narrow(grad_weight, -1, in_feat - 1, 1).squeeze(2).contiguous();
-    torch::Tensor grad_orig_weight = at::narrow(grad_weight, -1, 0, in_feat - 1).contiguous();
-
-    return {grad_orig_input_buf, grad_orig_weight, grad_orig_bias};
+    return {grad_input_buf, grad_weight, grad_bias};
 }
@@ -20,14 +20,15 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
         torch::Tensor expert_count,
-        torch::Tensor weight);
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias);
 
 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output_buf,
         torch::Tensor input_buf,
         torch::Tensor expert_count,
         torch::Tensor weight,
-        bool has_bias);
+        at::optional<torch::Tensor> bias);
 
 #ifdef MOE_USE_NCCL
...
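The header now simply mirrors the new signatures: the bias travels as at::optional<torch::Tensor> instead of a bool flag. A hedged usage sketch (variable names are assumptions, not repository code):

    // Pass a real tensor, or an empty optional when there is no bias.
    auto with_bias    = moe_cuda_forward(input_buf, expert_count, weight, bias);
    auto without_bias = moe_cuda_forward(input_buf, expert_count, weight, c10::nullopt);
    // moe_cuda_backward now returns {grad_input_buf, grad_weight, grad_bias};
    // grad_bias is currently an uninitialized placeholder of shape
    // [num_expert, out_feat], since the bias kernel is not implemented yet.
    auto grads = moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias);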