/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// NOTE(review): the original include target was lost in a paste; <cuda_runtime.h>
// is assumed here. fp8/half types are expected to come via "../common.h" — confirm.
#include <cuda_runtime.h>

#include "../common.h"

#ifdef __HIP_PLATFORM_AMD__
using __nv_fp8_e4m3 = te_hip_fp8_e4m3;
using __nv_fp8_e5m2 = te_hip_fp8_e5m2;
// HIP has no __ldlu (load-last-use); fall back to the read-only-cache load.
#define __ldlu(x) __ldg(x)
#endif

// Builds the token routing map used by the unpermute kernels.
// Grid: 1-D, one thread per (token, topK) pair; each thread writes one entry of
// row_id_map, laid out as row_id_map[topK][num_rows]. Entries beyond
// num_out_tokens mark dropped tokens with -1.
static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
                                           const int num_rows, const int topK,
                                           const int num_out_tokens) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;
  const int idx = bid * blockDim.x + tid;

  if (idx >= num_rows * topK) return;

  int source_row = sorted_row_id[idx];
  int source_token_id = source_row / topK;
  int source_topK_id = source_row % topK;

  if (idx >= num_out_tokens) {
    // Set the indices of dropped tokens to -1
    row_id_map[source_topK_id * num_rows + source_token_id] = -1;
  } else {
    // Create a row id map for subsequent unpermute operation
    row_id_map[source_topK_id * num_rows + source_token_id] = idx;
  }
}

// Gathers the topK permuted rows of each destination token and reduces them
// (prob-weighted sum when hasProb, plain sum otherwise; fp8 without probs is
// averaged by topK to avoid overflow).
// Grid: one block per destination token (blockIdx.x == token index).
// Dynamic smem: topK * sizeof(TCompute) floats of probabilities when hasProb.
// Precondition: num_cols must be a multiple of kElementsPerAccess so the
// float4 accesses stay 16-byte aligned.
template <typename T, typename TCompute, bool hasProb>
__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map,
                                     const float *prob, const int num_rows, const int topK,
                                     const int num_cols) {
  extern __shared__ int8_t s_mem[];
  TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);

  // Each block corresponds to one dest token
  const int source_token = blockIdx.x;
  const int tid = threadIdx.x;

  if (hasProb) {
    for (int i = tid; i < topK; i += blockDim.x * blockDim.y) {
      // Load all the topK probs related to the source row into smem
      s_prob[i] = TCompute(prob[source_token * topK + i]);
    }
    __syncthreads();
  }

  // Register buffers for vector type (float4) memory access
  float4 frag_load_store;
  T *frag_load_store_ptr = reinterpret_cast<T *>(&frag_load_store);

  // Number of elements in frag_load_store
  static constexpr int kElementsPerAccess = 16 / sizeof(T);

  // Traverse along the hidden dimension
  for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
    TCompute frag_elem[kElementsPerAccess];
    TCompute frag_sum[kElementsPerAccess];

    int source_row = row_id_map[source_token];

    // source_row == -1 represents a dropped token
    if (source_row != -1) {
      const T *source_row_ptr = input + source_row * num_cols;

      frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));

      for (int e = 0; e < kElementsPerAccess; e++) {
        frag_sum[e] = TCompute(frag_load_store_ptr[e]);
      }

      if (hasProb) {
        for (int e = 0; e < kElementsPerAccess; e++) {
          frag_sum[e] = frag_sum[e] * s_prob[0];
        }
      }
    } else {
      for (int e = 0; e < kElementsPerAccess; e++) {
        frag_sum[e] = TCompute(0.0f);
      }
    }

    // Accumulate the remaining topK - 1 expert outputs for this token.
    for (int k = 1; k < topK; k++) {
      source_row = row_id_map[k * num_rows + source_token];

      if (source_row == -1) continue;

      const T *source_row_ptr = input + source_row * num_cols;

      frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));

      for (int e = 0; e < kElementsPerAccess; e++) {
        frag_elem[e] = TCompute(frag_load_store_ptr[e]);
      }

      if (hasProb) {
        for (int e = 0; e < kElementsPerAccess; e++) {
          frag_elem[e] = frag_elem[e] * s_prob[k];
        }
      }

      for (int e = 0; e < kElementsPerAccess; e++) {
        frag_sum[e] = frag_sum[e] + frag_elem[e];
      }
    }

    T *dest_row_ptr = unpermuted_output + source_token * num_cols;

    for (int e = 0; e < kElementsPerAccess; e++) {
      if constexpr ((std::is_same_v<T, __nv_fp8_e5m2> || std::is_same_v<T, __nv_fp8_e4m3>) &&
                    (!hasProb)) {
        // fp8 has a narrow dynamic range: average instead of summing.
        frag_sum[e] = frag_sum[e] / TCompute(topK);
      }
      frag_load_store_ptr[e] = T(frag_sum[e]);
    }

    *reinterpret_cast<float4 *>(dest_row_ptr + i) = frag_load_store;
  }
}

// Scatters each source token's row into its topK destination rows.
// Serves three ops: permute fwd (hasProb == false, input_fwd == nullptr),
// unpermute bwd without probs, and unpermute bwd with probs (also computes
// prob_grad as the per-k inner product of grad and forward activation).
// Grid: one block per source token. topKTile is a compile-time upper bound
// on topK sizing the accum[] register array; the loops break at runtime topK.
// Dynamic smem: topK * sizeof(TCompute) when hasProb.
// When hasProb, blockDim.x must be a single warp (32) — the prob_grad
// reduction below is warp-level only.
template <typename T, typename TCompute, bool hasProb, int topKTile>
__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad,
                                   const float *prob, float *prob_grad, const int *row_id_map,
                                   const int num_rows, const int topK, const int num_cols) {
  extern __shared__ int8_t s_mem[];
  TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);

  // Each block corresponds to one source token
  const int source_token = blockIdx.x;
  const int tid = threadIdx.x;

  if (hasProb) {
    for (int i = tid; i < topK; i += blockDim.x) {
      // Load all the topK probs related to the source row into smem
      s_prob[i] = TCompute(prob[source_token * topK + i]);
    }
    __syncthreads();
  }

  // Accumulators for the calculation of prob_grad
  float accum[topKTile] = {0.0f};

  // Register buffers for vector type (float4) memory access
  float4 frag_load_store;
  T *frag_load_store_ptr = reinterpret_cast<T *>(&frag_load_store);

  // Number of elements in frag_load_store
  static constexpr int kElementsPerAccess = 16 / sizeof(T);

  // The starting address of each source row
  const T *source_row_ptr = input_bwd + source_token * num_cols;

  // Traverse along the hidden dimension
  for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
    TCompute frag_src[kElementsPerAccess];

    frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));

    for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]);

    int index = source_token;

    // Process each row in the corresponding topK rows
    for (int k = 0; k < topKTile; k++) {
      if (k == topK) break;

      int dest_row = row_id_map[index];
      index += num_rows;

      if (dest_row != -1) {
        if (hasProb) {
          // Calculate act_grad in unpermute bwd
          for (int e = 0; e < kElementsPerAccess; e++)
            frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]);
        } else {
          // permute fwd
          for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]);
        }

        T *dest_row_ptr = act_grad + dest_row * num_cols;
        *reinterpret_cast<float4 *>(dest_row_ptr + i) = frag_load_store;

        if (hasProb) {
          // Inner product calculation for prob_grad in unpermute bwd
          const T *input_fwd_ptr = input_fwd + dest_row * num_cols;
          frag_load_store = __ldlu(reinterpret_cast<const float4 *>(input_fwd_ptr + i));

          TCompute frag_input_fwd[kElementsPerAccess];
          for (int e = 0; e < kElementsPerAccess; e++)
            frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]);

          for (int e = 0; e < kElementsPerAccess; e++) {
            accum[k] += static_cast<float>(frag_src[e] * frag_input_fwd[e]);
          }
        }
      }
    }
  }

  if (hasProb) {
    for (int k = 0; k < topKTile; k++) {
      if (k == topK) break;

      // Warp-level reduction
      for (int mask = 16; mask > 0; mask /= 2) {
#ifdef __HIP_PLATFORM_AMD__
        accum[k] = accum[k] + __shfl_xor(accum[k], mask, 32);
#else
        accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
#endif
      }
    }

    if (tid == 0) {
      for (int k = 0; k < topKTile; k++) {
        if (k == topK) break;
        prob_grad[source_token * topK + k] = accum[k];
      }
    }
  }
}

// Host launcher for permute fwd (input_fwd == nullptr) and unpermute bwd.
// For fp8 types the per-element compute type TCompute is promoted to half;
// otherwise compute happens in T. Dispatches the topKTile template parameter
// on a power-of-two ladder up to the supported maximum of 128.
template <typename T>
void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map,
                           const float *prob, float *prob_grad, const T *input_fwd,
                           const int num_rows, const int topK, const int num_cols,
                           const int num_out_tokens, cudaStream_t stream) {
  using TCompute = typename std::conditional<(std::is_same<T, __nv_fp8_e5m2>::value ||
                                              std::is_same<T, __nv_fp8_e4m3>::value),
                                             half, T>::type;

  static constexpr int kElementsPerAccess = 16 / sizeof(T);

  if (input_fwd == nullptr) {
    // moe_permute_fwd

    int threads = 64;
    int blocks = (num_rows * topK + threads - 1) / threads;

    moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
                                                        num_out_tokens);

    blocks = num_rows;
#ifdef __HIP_PLATFORM_AMD__
    threads = std::min(num_cols / kElementsPerAccess, 256);
#else
    threads = std::min(num_cols / kElementsPerAccess, 1024);
#endif
    moe_permute_kernel<T, TCompute, false, 128><<<blocks, threads, 0, stream>>>(
        input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
  } else {
    // moe_unpermute_bwd

    // One warp per token: required by the warp-level prob_grad reduction.
    int threads = 32;
    int blocks = num_rows;

    if (prob == nullptr) {
      // moe_unpermute_bwd without probs
      moe_permute_kernel<T, TCompute, false, 128><<<blocks, threads, 0, stream>>>(
          input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
    } else {
      // moe_unpermute_bwd with probs
      size_t smem_bytes = topK * sizeof(TCompute);

      if (topK <= 8) {
        moe_permute_kernel<T, TCompute, true, 8><<<blocks, threads, smem_bytes, stream>>>(
            input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
      } else if (topK <= 16) {
        moe_permute_kernel<T, TCompute, true, 16><<<blocks, threads, smem_bytes, stream>>>(
            input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
      } else if (topK <= 32) {
        moe_permute_kernel<T, TCompute, true, 32><<<blocks, threads, smem_bytes, stream>>>(
            input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
      } else if (topK <= 64) {
        moe_permute_kernel<T, TCompute, true, 64><<<blocks, threads, smem_bytes, stream>>>(
            input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
      } else if (topK <= 128) {
        moe_permute_kernel<T, TCompute, true, 128><<<blocks, threads, smem_bytes, stream>>>(
            input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
      } else {
        NVTE_ERROR("topK cannot exceed 128.");
      }
    }
  }
}

// Host launcher for permute bwd and unpermute fwd (with or without probs).
// Same TCompute promotion rule as nvte_permute_launcher.
template <typename T>
void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob,
                             const int num_rows, const int topK, const int num_cols,
                             cudaStream_t stream) {
  using TCompute = typename std::conditional<(std::is_same<T, __nv_fp8_e5m2>::value ||
                                              std::is_same<T, __nv_fp8_e4m3>::value),
                                             half, T>::type;

  static constexpr int kElementsPerAccess = 16 / sizeof(T);

  int blocks = num_rows;
#ifdef __HIP_PLATFORM_AMD__
  int threads = std::min(num_cols / kElementsPerAccess, 256);
#else
  int threads = std::min(num_cols / kElementsPerAccess, 1024);
#endif
  size_t smem_bytes = topK * sizeof(TCompute);

  if (prob == nullptr) {
    // moe_permute_bwd
    // moe_unpermute_fwd without probs
    moe_unpermute_kernel<T, TCompute, false><<<blocks, threads, smem_bytes, stream>>>(
        input, output, row_id_map, nullptr, num_rows, topK, num_cols);
  } else {
    // moe_unpermute_fwd with probs
    moe_unpermute_kernel<T, TCompute, true><<<blocks, threads, smem_bytes, stream>>>(
        input, output, row_id_map, prob, num_rows, topK, num_cols);
  }
}

// C API entry point: unwraps NVTETensor handles and dispatches on the input
// dtype (all non-fp8-only types) to the templated launcher.
void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
                  NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
                  const NVTETensor input_fwd, const int num_rows, const int topK,
                  const int num_cols, const int num_out_tokens, cudaStream_t stream) {
  NVTE_API_CALL(nvte_permute);

  const transformer_engine::Tensor *input_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(input);
  const transformer_engine::Tensor *output_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(output);
  const transformer_engine::Tensor *sorted_row_id_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(sorted_row_id);
  const transformer_engine::Tensor *row_id_map_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
  const transformer_engine::Tensor *prob_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(prob);
  const transformer_engine::Tensor *prob_grad_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(prob_grad);
  const transformer_engine::Tensor *input_fwd_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input_cu->data.dtype, T,
      nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
                            reinterpret_cast<T *>(output_cu->data.dptr),
                            reinterpret_cast<const int *>(sorted_row_id_cu->data.dptr),
                            reinterpret_cast<int *>(row_id_map_cu->data.dptr),
                            reinterpret_cast<const float *>(prob_cu->data.dptr),
                            reinterpret_cast<float *>(prob_grad_cu->data.dptr),
                            reinterpret_cast<const T *>(input_fwd_cu->data.dptr), num_rows, topK,
                            num_cols, num_out_tokens, stream););
}

// C API entry point for the unpermute direction; mirrors nvte_permute.
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
                    const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
                    cudaStream_t stream) {
  NVTE_API_CALL(nvte_unpermute);

  const transformer_engine::Tensor *input_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(input);
  const transformer_engine::Tensor *output_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(output);
  const transformer_engine::Tensor *row_id_map_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
  const transformer_engine::Tensor *prob_cu =
      reinterpret_cast<const transformer_engine::Tensor *>(prob);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input_cu->data.dtype, T,
      nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
                              reinterpret_cast<T *>(output_cu->data.dptr),
                              reinterpret_cast<int *>(row_id_map_cu->data.dptr),
                              reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
                              num_cols, stream););
}