Unverified commit a26a7f1f, authored by Autumn1998, committed by GitHub
Browse files

Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162)



* Add bf16/fp32 token-per-expert support to the MoE aux-loss computation in the fused router
Signed-off-by: tongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: tongliu <tongliu@nvidia.com>
Co-authored-by: tongliu <tongliu@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 84fa28d2
...@@ -229,7 +229,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, ...@@ -229,7 +229,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
// Loop: for all positions in each row // Loop: for all positions in each row
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
float C_coeff = Const_buf[0]; float C_coeff = Const_buf[0];
IndexType tokens_per_expert_i = tokens_per_expert[i]; double tokens_per_expert_i = static_cast<double>(tokens_per_expert[i]);
double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]); double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
// Loop: for all rows // Loop: for all rows
for (int j = global_warp_id; j < num_rows; j += global_warp_num) { for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
......
...@@ -246,6 +246,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i ...@@ -246,6 +246,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
using type = int64_t; \ using type = int64_t; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} break; \ } break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment