Unverified Commit 62acae05 authored by Autumn1998, committed by GitHub

[PyTorch][MoE] Kernel fusions for the MoE router (#1883)



* add router fusion
Signed-off-by: tongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix ci
Signed-off-by: tongliu <tongliu@nvidia.com>

* fix ci with cuda 12.3
Signed-off-by: tongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix CI sm89/80
Signed-off-by: tongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: tongliu <tongliu@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: tongliu <tongliu@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 64891899
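
For orientation, a minimal usage sketch of the main fused op added here, based on the signatures exercised by tests/pytorch/test_fused_router.py below (shapes and argument values are illustrative, not prescriptive):

import torch
from transformer_engine.pytorch.router import fused_topk_with_score_function

# [num_tokens, num_experts] router logits
logits = torch.randn(128, 32, device="cuda", requires_grad=True)
probs, routing_map = fused_topk_with_score_function(
    logits=logits,
    topk=4,
    use_pre_softmax=False,
    num_groups=None,
    group_topk=None,
    scaling_factor=None,
    score_function="softmax",
    expert_bias=None,
)
probs.sum().backward()  # gradients flow back to the logits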
@@ -46,6 +46,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_P
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
if [ "$RET" -ne 0 ]; then
    echo "Error in the following test cases:$FAILED_CASES"
......
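
The new test file added below can also be run standalone, mirroring the qa script line above (run from the repository root; the qa script prefixes paths with $TE_PATH):

python3 -m pytest -v tests/pytorch/test_fused_router.py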
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from copy import deepcopy
from typing import Optional, Tuple

from transformer_engine.pytorch.router import (
    fused_topk_with_score_function,
    fused_compute_score_for_moe_aux_loss,
    fused_moe_aux_loss,
)

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# PyTorch reference implementation of group-limited topk
def group_limited_topk(
scores: torch.Tensor,
topk: int,
num_tokens: int,
num_experts: int,
num_groups: int,
group_topk: int,
):
group_scores = (
scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
# Mask the experts based on selection groups
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_tokens, num_groups, num_experts // num_groups)
.reshape(num_tokens, -1)
)
masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
return probs, top_indices
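
# Note: group_limited_topk above mirrors group-limited routing: each group of
# experts is scored by the sum of its top-(topk // group_topk) expert scores,
# the best group_topk groups are kept, and the final topk runs over experts in
# the surviving groups only.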
# PyTorch reference topk with softmax/sigmoid score function
def topk_softmax_sigmoid_pytorch(
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
):
num_tokens, num_experts = logits.shape
def compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
topk=topk,
num_tokens=num_tokens,
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
)
else:
return torch.topk(scores, k=topk, dim=1)
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits.float()).type_as(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
if scaling_factor:
probs = probs * scaling_factor
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
return topk_masked_gates, topk_map
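
# Note: both outputs are dense [num_tokens, num_experts] tensors:
# topk_masked_gates holds the (optionally scaled) gate values at selected
# positions and zeros elsewhere, and topk_map marks the selected experts.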
# PyTorch reference: compute routing scores for the aux loss
def compute_scores_for_aux_loss_pytorch(
    logits: torch.Tensor, topk: int, score_function: str
) -> Tuple[torch.Tensor, torch.Tensor]:
if score_function == "softmax":
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
return routing_map, scores
# PyTorch reference implementation of the aux loss
def aux_loss_pytorch(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
topk: int,
num_experts: int,
moe_aux_loss_coeff: float,
):
aggregated_probs_per_expert = probs.sum(dim=0)
aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens)
)
return aux_loss
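
# For reference, the loss above evaluates
#   aux_loss = moe_aux_loss_coeff * num_experts / (topk * T^2) * sum_e(P_e * n_e)
# with T = total_num_tokens, P_e = probs[:, e].sum(), and n_e = tokens_per_expert[e];
# the fused kernel below uses the same normalization constant (C_coeff).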
def run_comparison(
dtype,
num_tokens,
num_experts,
topk,
use_pre_softmax,
num_groups,
group_topk,
scaling_factor,
score_function,
enable_bias,
):
    # Construct deterministic input logits
if score_function == "sigmoid":
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
logits = logits.unsqueeze(0).repeat(num_tokens, 1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
expert_bias = torch.arange(num_experts, device="cuda") * 0.1
expert_bias.requires_grad = True
else:
expert_bias = None
# Clone the input tensor
logits_clone = deepcopy(logits)
logits_clone.requires_grad = True
if expert_bias is not None:
expert_bias_clone = deepcopy(expert_bias)
expert_bias_clone.requires_grad = True
else:
expert_bias_clone = None
# Run the original implementation
# We do not support the capacity factor case
probs, routing_map = topk_softmax_sigmoid_pytorch(
logits=logits,
topk=topk,
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function=score_function,
expert_bias=expert_bias,
)
# Run the fused implementation
probs_fused, routing_map_fused = fused_topk_with_score_function(
logits=logits_clone,
topk=topk,
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function=score_function,
expert_bias=expert_bias_clone,
)
torch.testing.assert_close(probs, probs_fused)
torch.testing.assert_close(routing_map, routing_map_fused)
    # Build a dummy scalar loss
    loss = torch.sum(probs)
    loss_fused = torch.sum(probs_fused)
    # Backpropagate
    loss.backward()
    loss_fused.backward()
    # Check the gradients
    torch.testing.assert_close(logits.grad, logits_clone.grad)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
def test_topk_sigmoid(
dtype,
num_tokens,
num_experts,
topk,
group_topk,
scaling_factor,
enable_bias,
):
num_groups = 8 if group_topk else None
run_comparison(
dtype=dtype,
num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
use_pre_softmax=False,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function="sigmoid",
enable_bias=enable_bias,
)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
def test_topk_softmax(
dtype,
num_tokens,
num_experts,
topk,
use_pre_softmax,
group_topk,
scaling_factor,
):
num_groups = 8 if group_topk else None
run_comparison(
dtype=dtype,
num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function="softmax",
enable_bias=False,
)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
logits.requires_grad = True
logits_clone = deepcopy(logits)
logits_clone.requires_grad = True
routing_map, scores = compute_scores_for_aux_loss_pytorch(
logits=logits,
topk=topk,
score_function=score_function,
)
routing_map_fused, scores_fused = fused_compute_score_for_moe_aux_loss(
logits=logits_clone,
topk=topk,
score_function=score_function,
)
torch.testing.assert_close(scores, scores_fused)
torch.testing.assert_close(routing_map, routing_map_fused)
loss = torch.sum(scores)
loss.backward()
loss_fused = torch.sum(scores_fused)
loss_fused.backward()
torch.testing.assert_close(logits.grad, logits_clone.grad)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
probs = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
probs.requires_grad = True
tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
coeff = 0.01
probs_clone = deepcopy(probs)
probs_clone.requires_grad = True
aux_loss = aux_loss_pytorch(
probs=probs,
tokens_per_expert=tokens_per_expert,
total_num_tokens=num_tokens,
topk=topk,
num_experts=num_experts,
moe_aux_loss_coeff=coeff,
)
aux_loss_fused = fused_moe_aux_loss(
probs=probs_clone,
tokens_per_expert=tokens_per_expert,
total_num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
coeff=coeff,
)
torch.testing.assert_close(aux_loss, aux_loss_fused)
# Backward
aux_loss.backward()
aux_loss_fused.backward()
torch.testing.assert_close(probs.grad, probs_clone.grad)
def profile_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    enable_bias,
    use_pre_softmax,
):
    group_topk = 4
    scaling_factor = 1.2
    test_topk_sigmoid(
        dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
    )
    test_topk_softmax(
        dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
    )
if __name__ == "__main__":
test_fused_scores_for_aux_loss(
dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=256, topk=4)
@@ -99,6 +99,9 @@ list(APPEND transformer_engine_SOURCES
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_softmax/scaled_aligned_causal_masked_softmax.cu
     fused_rope/fused_rope.cu
     fused_router/fused_moe_aux_loss.cu
     fused_router/fused_score_for_moe_aux_loss.cu
     fused_router/fused_topk_with_score_function.cu
     recipe/current_scaling.cu
     recipe/delayed_scaling.cu
     recipe/fp8_block_scaling.cu
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "common/util/cuda_runtime.h"
#include "utils.h"
namespace transformer_engine {
// Use double to handle all intermediate calculations
using CompType = double;
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff,
DataType* aux_loss, float* Const_buf) {
#if __CUDA_ARCH__ >= 900
// Using cooperative_groups to manage the cluster
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
int thread_id = cg::this_grid().thread_rank();
int lane_id = thread_id % kThreadsPerWarp;
int warp_id = thread_id / kThreadsPerWarp;
int warp_num = blockDim.x * gridDim.x / kThreadsPerWarp;
  // Rank of this block within the cluster
  int block_id = cluster.block_rank();
  int block_num = cluster.dim_blocks().x;
  int cluster_id = blockIdx.x / block_num;
  if (cluster_id > 0) return;  // Only cluster 0 does the work
extern __shared__ float shmem_aux_loss[];
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
/**
* Section: Reduce the probs to the aggregated_probs_per_expert
* 1. reduce on the block
* 2. reduce on the cluster
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_tokens; j += warp_num) {
tmp += CompType(probs[j * num_experts + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
cluster.sync();
// The block 0 will reduce the results of all blocks
if (block_id == 0) {
for (int i = 1; i < block_num; i++) {
// Map the shared memory of the block i to the current block
CompType* dst_smem = reinterpret_cast<CompType*>(cluster.map_shared_rank(shmem_aux_loss, i));
for (int j = threadIdx.x; j < num_experts; j += blockDim.x) {
atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]);
}
}
}
cluster.sync();
/**
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
if (block_id == 0) {
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
if (warp_id == 0) {
/**
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
__syncwarp();
if (lane_id == 0) {
/**
* Section: Compute the aux_loss
*/
float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
Const_buf[0] = C_coeff;
}
}
}
#else
  // Use only one block (1024 threads) to avoid grid-level synchronization
if (blockIdx.x > 0) return;
int warp_num = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem_aux_loss[];
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
/**
* Section: Reduce the probs to the aggregated_probs_per_expert
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_tokens; j += warp_num) {
tmp += CompType(probs[j * num_experts + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
__syncthreads();
/**
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
if (warp_id == 0) {
/**
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
__syncwarp();
if (lane_id == 0) {
/**
* Section: Compute the aux_loss
*/
float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
Const_buf[0] = C_coeff;
}
}
#endif
}
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff,
DataType* aux_loss, float* Const_buf,
cudaStream_t stream) {
if (cuda::sm_arch(cuda::current_device()) >= 900) {
    cudaLaunchConfig_t config = {0};
    int cluster_size = 8;
    config.gridDim = cluster_size;
    config.blockDim = 1024;
    config.dynamicSmemBytes = sizeof(CompType) * num_experts;
    config.stream = stream;  // launch on the caller's stream
// Update the max cluster size based on the device
cudaOccupancyMaxPotentialClusterSize(
&cluster_size,
reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeClusterDimension;
attribute[0].val.clusterDim.x = cluster_size;
attribute[0].val.clusterDim.y = 1;
attribute[0].val.clusterDim.z = 1;
config.numAttrs = 1;
config.attrs = attribute;
cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
tokens_per_expert, total_num_tokens, num_tokens, num_experts, topk, coeff,
aux_loss, Const_buf);
} else {
size_t smem_size = sizeof(CompType) * num_experts;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_tokens,
num_experts, topk, coeff, aux_loss, Const_buf);
}
}
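
// Note: on sm90+ the launcher above uses a thread block cluster so that block 0
// can read the other blocks' shared-memory partial sums directly via
// cluster.map_shared_rank(), avoiding a separate global-memory reduction pass.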
void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts, int topk,
float coeff, Tensor& aux_loss, Tensor& Const_buf,
cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
probs.data.dtype, DataType,
TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
tokens_per_expert.data.dtype, IndexType,
fused_moe_aux_loss_forward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<DataType*>(probs.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
num_tokens, num_experts, topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
}
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
const IndexType* tokens_per_expert,
int num_tokens, int num_experts,
DataType* grad_aux_loss, DataType* grad_probs) {
int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
float C_coeff = Const_buf[0];
IndexType tokens_per_expert_i = tokens_per_expert[i];
double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
// Loop: for all rows
for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
}
}
}
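
// Note: since aux_loss = C_coeff * sum_i(tokens_per_expert[i] * sum_j(probs[j][i])),
// d(aux_loss)/d(probs[j][i]) = C_coeff * tokens_per_expert[i], so the kernel above
// fills each column i of grad_probs with that constant times grad_aux_loss.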
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
const IndexType* tokens_per_expert, int num_tokens,
int num_experts, DataType* grad_aux_loss,
DataType* grad_probs, cudaStream_t stream) {
  // Metadata for the kernel launch
int block_size = 256;
int grid_size = (num_tokens + block_size - 1) / block_size;
fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs);
}
void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
int num_tokens, int num_experts, Tensor& grad_aux_loss,
Tensor& grad_probs, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_aux_loss.data.dtype, DataType,
TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
tokens_per_expert.data.dtype, IndexType,
fused_moe_aux_loss_backward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<float*>(Const_buf.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_tokens, num_experts,
reinterpret_cast<DataType*>(grad_aux_loss.data.dptr),
reinterpret_cast<DataType*>(grad_probs.data.dptr), stream);););
}
} // namespace transformer_engine
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts,
int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
using namespace transformer_engine;
fused_moe_aux_loss_forward(
*convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss),
*convertNVTETensorCheck(Const_buf), stream);
}
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_tokens,
int num_experts, NVTETensor grad_aux_loss,
NVTETensor grad_probs, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
using namespace transformer_engine;
fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
*convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts,
*convertNVTETensorCheck(grad_aux_loss),
*convertNVTETensorCheck(grad_probs), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"
namespace transformer_engine {
template <typename DataType>
__global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens,
int num_experts, int topk,
int score_function, DataType *scores,
bool *routing_map,
DataType *intermediate_output) {
  /***
   * Section: Global variables/addresses init
   * - Assumes sizeof(DataType) >= sizeof(int), so the DataType buffers are
   *   laid out first to avoid alignment issues.
   * - Each warp is responsible for one token and has its own shared memory
   *   buffer, so __syncwarp() is used instead of __syncthreads().
   */
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem_scores_for_aux_loss[];
DataType *logits_buf = reinterpret_cast<DataType *>(shmem_scores_for_aux_loss);
DataType *topk_logits_buf =
reinterpret_cast<DataType *>(logits_buf + num_experts * num_token_per_block);
int *topk_indices_buf = reinterpret_cast<int *>(topk_logits_buf + topk * num_token_per_block);
// The address of buffers on the current warp
DataType *local_logits = logits_buf + warp_id * num_experts;
DataType *topk_logits = topk_logits_buf + warp_id * topk;
int *topk_indices = topk_indices_buf + warp_id * topk;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the logits to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the routing_map (num_experts)
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
routing_map[pos_offset + i] = false;
if (score_function == 1) {
intermediate_output[pos_offset + i] = -std::numeric_limits<DataType>::infinity();
}
}
// Load the logits to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = logits[pos_offset + i];
}
__threadfence_block();
__syncwarp();
    /***
     * Section: Preprocess
     * Optionally preprocess the scores before the topk operation:
     * - Pre-softmax
     * - Sigmoid
     * - Sigmoid post-processing when topk > 1
     * The scores are updated in place.
     */
// score_function == 1 means softmax
if (score_function == 1) {
// Apply softmax to the logits before the topk
apply_softmax_on_float(local_logits, num_experts, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i];
}
}
// score_function == 0 means sigmoid
if (score_function == 0) {
// Apply sigmoid to the logits
apply_sigmoid_on_float(local_logits, num_experts, lane_id);
__syncwarp();
// Save the sigmoid output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = local_logits[i];
}
}
    __syncwarp();  // Ensure the softmax/sigmoid output has been written
if (score_function == 0) {
if (topk > 1) {
auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
(static_cast<double>(sum_logits) + epsilon));
}
}
__syncwarp();
}
/***
* Section: Topk
* Get the topk indices
*/
naive_topk_and_mask(local_logits, num_experts, topk, topk_indices, topk_logits, lane_id);
__syncwarp();
// Write the routing_map to the output tensor
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
routing_map[pos_offset + topk_indices[i]] = true;
}
// Write the scores to the output tensor
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
scores[pos_offset + i] = local_logits[i];
}
__threadfence_block();
__syncwarp();
}
}
template <typename DataType>
void fused_score_for_moe_aux_loss_forward_kernel_launcher(
const DataType *logits, int num_tokens, int num_experts, int topk, int score_function,
DataType *scores, bool *routing_map, DataType *intermediate_output, cudaStream_t stream) {
  // Metadata for the kernel launch
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // logits
+ topk * num_token_per_block * sizeof(DataType) // topk_logits
+ topk * num_token_per_block * sizeof(int); // topk_indices
fused_score_for_moe_aux_loss_forward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, score_function, scores, routing_map,
intermediate_output);
}
void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts,
int topk, int score_function, Tensor &scores,
Tensor &routing_map, Tensor &intermediate_output,
cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
logits.data.dtype, DataType,
fused_score_for_moe_aux_loss_forward_kernel_launcher<DataType>(
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,
score_function, reinterpret_cast<DataType *>(scores.data.dptr),
reinterpret_cast<bool *>(routing_map.data.dptr),
reinterpret_cast<DataType *>(intermediate_output.data.dptr), stream););
}
template <typename DataType>
__global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *intermediate_output,
const DataType *grad_scores,
int num_tokens, int num_experts,
int topk, int score_function,
DataType *grad_logits) {
  /***
   * Section: Global variables/addresses init
   * - Assumes sizeof(DataType) >= sizeof(int), so the DataType buffers are
   *   laid out first to avoid alignment issues.
   * - Each warp is responsible for one token and has its own shared memory
   *   buffer, so __syncwarp() is used instead of __syncthreads().
   */
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem[];
DataType *grad_scores_buf = reinterpret_cast<DataType *>(shmem);
// To store the output of softmax/sigmoid from the fwd
DataType *act_from_fwd_buf =
reinterpret_cast<DataType *>(grad_scores_buf + num_experts * num_token_per_block);
DataType *comp_buf =
reinterpret_cast<DataType *>(act_from_fwd_buf + num_experts * num_token_per_block);
// The address of buffers on the current warp
DataType *local_grad = grad_scores_buf + warp_id * num_experts;
DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts;
DataType *local_comp_buf = comp_buf + warp_id * num_experts;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the dgrad/output_from_fwd to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the logits_grad in global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = 0.0f;
}
// Load the dgrad/output_from_fwd to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] = grad_scores[pos_offset + i];
local_act_from_fwd[i] = intermediate_output[pos_offset + i];
}
__threadfence_block();
__syncwarp();
/***
* Section: Backward of ops before the topk
* - Pre-softmax bwd
* - Sigmoid Post-processing bwd when topk > 1
* - Sigmoid bwd
* - Write the grad_logits to the global mem
*/
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
}
__syncwarp();
auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] =
static_cast<double>(local_grad[i]) / (static_cast<double>(sum_fwd_input) + epsilon) -
static_cast<double>(sum_Output_x_Grad) /
((static_cast<double>(sum_fwd_input) + epsilon) *
(static_cast<double>(sum_fwd_input) + epsilon));
}
}
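    // Derivation sketch: for y_i = x_i / S with S = sum_j(x_j) and incoming grad g,
    // dL/dx_i = g_i / S - (sum_j(g_j * x_j)) / S^2, which is exactly the in-place
    // update above (epsilon guards against S == 0).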
__syncwarp();
// Pre-softmax bwd
if (score_function == 1) {
apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr,
num_experts, lane_id);
__syncwarp();
}
// Sigmoid bwd
if (score_function == 0) {
apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
__syncwarp();
}
// Write the grad_logits to the global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = local_grad[i];
}
__syncwarp();
}
}
template <typename DataType>
void fused_score_for_moe_aux_loss_backward_kernel_launcher(
const DataType *intermediate_output, const DataType *grad_scores, int num_tokens,
int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) {
  // Metadata for the kernel launch
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_scores
+
num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd
+ num_experts * num_token_per_block * sizeof(DataType); // comp_buf
fused_score_for_moe_aux_loss_backward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function,
grad_logits);
}
void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output,
const Tensor &grad_scores, int num_tokens,
int num_experts, int topk, int score_function,
Tensor &grad_logits, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_scores.data.dtype, DataType,
fused_score_for_moe_aux_loss_backward_kernel_launcher<DataType>(
reinterpret_cast<DataType *>(intermediate_output.data.dptr),
reinterpret_cast<DataType *>(grad_scores.data.dptr), num_tokens, num_experts, topk,
score_function, reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
}
} // namespace transformer_engine
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor scores, const NVTETensor routing_map,
const NVTETensor intermediate_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward);
using namespace transformer_engine;
fused_score_for_moe_aux_loss_forward(*convertNVTETensorCheck(logits), num_tokens, num_experts,
topk, score_function, *convertNVTETensorCheck(scores),
*convertNVTETensorCheck(routing_map),
*convertNVTETensorCheck(intermediate_output), stream);
}
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,
const NVTETensor grad_scores, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor grad_logits, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_backward);
using namespace transformer_engine;
fused_score_for_moe_aux_loss_backward(
*convertNVTETensorCheck(intermediate_output), *convertNVTETensorCheck(grad_scores),
num_tokens, num_experts, topk, score_function, *convertNVTETensorCheck(grad_logits), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "utils.h"
namespace transformer_engine {
template <typename DataType, typename BiasType>
__global__ void fused_topk_with_score_function_forward_kernel(
const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const BiasType *expert_bias, DataType *probs, bool *routing_map,
DataType *intermediate_output) {
  /***
   * Section: Global variables/addresses init
   * - Assumes sizeof(DataType) >= sizeof(int), so the DataType buffers are
   *   laid out first to avoid alignment issues.
   * - Each warp is responsible for one token and has its own shared memory
   *   buffer, so __syncwarp() is used instead of __syncthreads().
   */
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem[];
DataType *scores_buf = reinterpret_cast<DataType *>(shmem);
DataType *topk_scores_buf =
reinterpret_cast<DataType *>(scores_buf + num_experts * num_token_per_block);
DataType *group_scores_buf = nullptr, *masked_scores_buf = nullptr;
int *topk_indices_buf = nullptr;
if (group_topk > 0) {
masked_scores_buf = reinterpret_cast<DataType *>(topk_scores_buf + topk * num_token_per_block);
group_scores_buf =
reinterpret_cast<DataType *>(masked_scores_buf + num_experts * num_token_per_block);
topk_indices_buf = reinterpret_cast<int *>(group_scores_buf + num_groups * num_token_per_block);
} else {
topk_indices_buf = reinterpret_cast<int *>(topk_scores_buf + topk * num_token_per_block);
}
// The address of buffers on the current warp
DataType *scores = scores_buf + warp_id * num_experts;
DataType *topk_scores = topk_scores_buf + warp_id * topk;
DataType *masked_scores = masked_scores_buf + warp_id * num_experts;
DataType *group_scores = group_scores_buf + warp_id * num_groups;
int *topk_indices = topk_indices_buf + warp_id * topk;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the logits to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the probs/routing_map (num_experts)
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
probs[pos_offset + i] = 0.0f;
routing_map[pos_offset + i] = false;
if (score_function == 1) {
intermediate_output[pos_offset + i] = -std::numeric_limits<DataType>::infinity();
}
}
// Load the logits to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
scores[i] = logits[pos_offset + i];
}
// If group_topk > 0, init the masked_scores to -inf
if (group_topk > 0) {
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
masked_scores[i] = -std::numeric_limits<DataType>::infinity();
}
}
__threadfence_block();
__syncwarp();
    /***
     * Section: Preprocess
     * Optionally preprocess the scores before the topk operation:
     * - Pre-softmax
     * - Sigmoid
     * - Expert bias
     * The scores are updated in place.
     */
// score_function == 1 means softmax
if (use_pre_softmax && score_function == 1) {
// Apply softmax to the logits before the topk
apply_softmax_on_float(scores, num_experts, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = scores[i];
}
}
// score_function == 0 means sigmoid
if (score_function == 0) {
// Apply sigmoid to the logits
apply_sigmoid_on_float(scores, num_experts, lane_id);
__syncwarp();
// Save the sigmoid output for backward
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
intermediate_output[pos_offset + i] = scores[i];
}
}
    __syncwarp();  // Ensure the softmax/sigmoid output has been written
// Expert bias is only used at the sigmoid case
if (expert_bias && score_function == 0) {
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
scores[i] = static_cast<DataType>(static_cast<double>(scores[i]) +
static_cast<double>(expert_bias[i]));
}
}
__syncwarp();
/***
* Section: Topk
* Get the topk indices
* - group_topk
* - naive topk
* - topk with expert bias
*/
    // Topk on the scores
    // A non-empty expert bias can only occur in the sigmoid case
if (group_topk > 0) {
int group_size = num_experts / num_groups;
      // Per-group top-(topk / group_topk)
for (int i = 0; i < num_groups; i++) {
naive_topk_and_mask(
/*scores ptr = */ scores + i * group_size,
/*data size = */ group_size,
/*topk = */ topk / group_topk,
/*topk indices ptr = */ topk_indices,
/*topk scores ptr = */ topk_scores,
/*lane id = */ lane_id);
__syncwarp();
// Compute the group score
if (lane_id == 0) {
DataType tmp = 0.0f;
for (int j = 0; j < topk / group_topk; j++) {
tmp = tmp + topk_scores[j];
}
group_scores[i] = tmp;
}
__syncwarp();
}
// select the topk groups
naive_topk_and_mask(
/*scores ptr = */ group_scores,
/*data size = */ num_groups,
/*topk = */ group_topk,
/*topk indices ptr = */ topk_indices,
/*topk scores ptr = */ topk_scores,
/*lane id = */ lane_id);
__syncwarp();
// Copy the unmasked scores to the buffer
for (int i = 0; i < group_topk; i++) {
int st = topk_indices[i] * group_size;
int ed = st + group_size;
for (int j = st + lane_id; j < ed; j += kThreadsPerWarp) {
masked_scores[j] = scores[j];
}
}
__syncwarp();
naive_topk_and_mask(masked_scores, num_experts, topk, topk_indices, topk_scores, lane_id);
} else {
naive_topk_and_mask(scores, num_experts, topk, topk_indices, topk_scores, lane_id);
}
__syncwarp();
    /***
     * Section: Postprocess
     * Optionally postprocess the scores after the topk operation:
     * - Revert the expert bias
     * - Softmax
     * - Sigmoid post-processing when topk > 1
     * - Write the result with scaling_factor
     */
// Revert Expert bias from the topk scores
if (expert_bias && score_function == 0) {
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] =
static_cast<double>(topk_scores[i]) - static_cast<double>(expert_bias[topk_indices[i]]);
}
}
__syncwarp();
// score_function == 1 means softmax
if (!use_pre_softmax && score_function == 1) {
// Apply softmax to the topk logits
apply_softmax_on_float(topk_scores, topk, lane_id);
__syncwarp();
// Save the softmax output for backward
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i];
}
}
// score_function == 0 means sigmoid
if (score_function == 0) {
if (topk > 1) {
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
}
}
__syncwarp();
}
// Write the probs/routing_map to the output tensor
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
routing_map[pos_offset + topk_indices[i]] = true;
probs[pos_offset + topk_indices[i]] = scaling_factor * static_cast<double>(topk_scores[i]);
}
__threadfence_block();
__syncwarp();
}
}
template <typename DataType, typename BiasType>
void fused_topk_with_score_function_forward_kernel_launcher(
const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const BiasType *expert_bias, DataType *probs, bool *routing_map, DataType *intermediate_output,
cudaStream_t stream) {
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // scores
+ topk * num_token_per_block * sizeof(DataType) // topk_scores
+ topk * num_token_per_block * sizeof(int); // topk_indices
if (group_topk > 0) {
shared_memory_size += num_groups * num_token_per_block * sizeof(DataType); // group_scores
    shared_memory_size += num_experts * num_token_per_block * sizeof(DataType);  // masked_scores
}
fused_topk_with_score_function_forward_kernel<DataType, BiasType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk,
scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output);
}
void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts,
int topk, bool use_pre_softmax, int num_groups,
int group_topk, float scaling_factor,
int score_function, const Tensor expert_bias,
Tensor probs, Tensor routing_map,
Tensor intermediate_output, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
logits.data.dtype, DataType,
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
expert_bias.data.dtype, BiasType,
fused_topk_with_score_function_forward_kernel_launcher<DataType, BiasType>(
reinterpret_cast<DataType *>(logits.data.dptr), num_tokens, num_experts, topk,
use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
reinterpret_cast<BiasType *>(expert_bias.data.dptr),
reinterpret_cast<DataType *>(probs.data.dptr),
reinterpret_cast<bool *>(routing_map.data.dptr),
reinterpret_cast<DataType *>(intermediate_output.data.dptr), stream);););
}
template <typename DataType>
__global__ void fused_topk_with_score_function_backward_kernel(
// Inputs tensor
const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs,
// Other parameters
int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor,
int score_function,
// Output tensor
DataType *grad_logits) {
  /***
   * Section: Global variables/addresses init
   * - Assumes sizeof(DataType) >= sizeof(int), so the DataType buffers are
   *   laid out first to avoid alignment issues.
   * - Each warp is responsible for one token and has its own shared memory
   *   buffer, so __syncwarp() is used instead of __syncthreads().
   */
// Used variables/addresses init
int num_token_per_block = blockDim.x / kThreadsPerWarp;
int warp_id = threadIdx.x / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
extern __shared__ float shmem[];
DataType *grad_probs_buf = reinterpret_cast<DataType *>(shmem);
// To store the output of softmax/sigmoid from the fwd
DataType *act_from_fwd_buf =
reinterpret_cast<DataType *>(grad_probs_buf + num_experts * num_token_per_block);
DataType *comp_buf =
reinterpret_cast<DataType *>(act_from_fwd_buf + num_experts * num_token_per_block);
// To store the routing_map from the fwd
bool *routing_map_buf = reinterpret_cast<bool *>(comp_buf + num_experts * num_token_per_block);
// The address of buffers on the current warp
DataType *local_grad = grad_probs_buf + warp_id * num_experts;
DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts;
DataType *local_comp_buf = comp_buf + warp_id * num_experts;
bool *local_routing_map = routing_map_buf + warp_id * num_experts;
/***
* Section: Main Loop
* - Each warp is responsible for one token
*/
int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block;
for (int round = blockIdx.x; round < total_round; round += gridDim.x) {
int token_offset_cur_warp = round * num_token_per_block + warp_id;
// Each warp is responsible for one token
if (token_offset_cur_warp >= num_tokens) break;
/***
* Section: Init buffer
* - Clear the global buffer which will accept the result of this round
* - Clear/Init the shmem buffer used by current warp this round
* - Load the dgrad/output_from_fwd to shmem
*/
int pos_offset = token_offset_cur_warp * num_experts;
// Clear the logits_grad in global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = 0.0f;
}
// Load the dgrad/output_from_fwd to shmem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] = grad_probs[pos_offset + i];
local_act_from_fwd[i] = intermediate_output[pos_offset + i];
local_routing_map[i] = routing_map[pos_offset + i];
}
__threadfence_block();
__syncwarp();
/***
* Section: Backward of ops after the topk
* - Backward of the used scaling_factor
* - Sigmoid Post-processing bwd when topk > 1
* - Softmax bwd if use_pre_softmax is false
*/
// Backward of the used scaling_factor
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) {
local_grad[i] = static_cast<double>(local_grad[i]) * scaling_factor;
}
}
__syncwarp();
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
double sum_fwd_input = masked_warp_reduce_on_shmem(
/*data ptr = */ local_act_from_fwd,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
static_cast<double>(local_act_from_fwd[i])
: 0.0f);
}
__syncwarp();
double sum_Output_x_Grad = masked_warp_reduce_on_shmem(
/*data ptr = */ local_comp_buf,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) {
local_grad[i] =
static_cast<double>(local_grad[i]) / (sum_fwd_input + epsilon) -
sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon));
} else {
local_grad[i] = 0.0f;
}
}
}
__syncwarp();
// Softmax bwd if use_pre_softmax is false
if (!use_pre_softmax && score_function == 1) {
apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map,
num_experts, lane_id);
__syncwarp();
}
/***
* Section: Backward of topk
* mask the unselected position in the grad
*/
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (!local_routing_map[i]) {
local_grad[i] = 0.0f;
}
}
__syncwarp();
/***
* Section: Backward of ops before the topk
* - Pre-softmax bwd
* - Sigmoid bwd
* - Write the grad_logits to the global mem
*/
// Pre-softmax bwd
if (score_function == 1 && use_pre_softmax) {
apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, nullptr,
num_experts, lane_id);
__syncwarp();
}
// Sigmoid bwd
if (score_function == 0) {
apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id);
__syncwarp();
}
// Write the grad_logits to the global mem
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
grad_logits[pos_offset + i] = local_grad[i];
}
__syncwarp();
}
}
template <typename DataType>
void fused_topk_with_score_function_backward_kernel_launcher(
const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs,
int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor,
int score_function, DataType *grad_logits, cudaStream_t stream) {
  // Metadata for the kernel launch
size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp;
size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block;
size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_probs
+
num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd
+ num_experts * num_token_per_block * sizeof(DataType) // comp_buf
+ num_experts * num_token_per_block * sizeof(bool); // routing_map
fused_topk_with_score_function_backward_kernel<DataType>
<<<grid_size, kThreadsPerBlock, shared_memory_size, stream>>>(
routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,
use_pre_softmax, scaling_factor, score_function, grad_logits);
}
void fused_topk_with_score_function_backward(const Tensor &routing_map,
const Tensor &intermediate_output,
const Tensor &grad_probs, int num_tokens,
int num_experts, int topk, bool use_pre_softmax,
float scaling_factor, int score_function,
Tensor &grad_logits, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_logits.data.dtype, DataType,
fused_topk_with_score_function_backward_kernel_launcher<DataType>(
reinterpret_cast<bool *>(routing_map.data.dptr),
reinterpret_cast<DataType *>(intermediate_output.data.dptr),
reinterpret_cast<DataType *>(grad_probs.data.dptr), num_tokens, num_experts, topk,
use_pre_softmax, scaling_factor, score_function,
reinterpret_cast<DataType *>(grad_logits.data.dptr), stream););
}
} // namespace transformer_engine
void nvte_fused_topk_with_score_function_forward(
const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map,
NVTETensor intermediate_output, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_topk_with_score_function_forward);
using namespace transformer_engine;
fused_topk_with_score_function_forward(
*convertNVTETensorCheck(logits), num_tokens, num_experts, topk,
static_cast<bool>(use_pre_softmax), num_groups, group_topk, scaling_factor, score_function,
*convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs),
*convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), stream);
}
void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map,
const NVTETensor intermediate_output,
const NVTETensor grad_probs, int num_tokens,
int num_experts, int topk, int use_pre_softmax,
float scaling_factor, int score_function,
NVTETensor grad_logits, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_topk_with_score_function_backward);
using namespace transformer_engine;
fused_topk_with_score_function_backward(
*convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output),
*convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk,
static_cast<bool>(use_pre_softmax), scaling_factor, score_function,
*convertNVTETensorCheck(grad_logits), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
constexpr size_t kThreadsPerWarp = 32;
constexpr int kThreadsPerBlock =
    128;  // 4 warps per CTA; each warp is responsible for one token
constexpr float epsilon = 1e-20;
template <typename T>
__device__ inline T max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
__device__ inline T sum(T a, T b) {
return a + b;
}
template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
int lane_id) {
  // Each thread first reduces the elements it owns;
  // thread 0 handles the 0th, 32nd, 64th, 96th, ... elements.
volatile double val =
lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
val = reduce_func(val, data_ptr[i]);
}
// Warp shuffle between threads
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
__syncwarp();
return T(val);
}
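For orientation, the XOR-shuffle sequence above is a butterfly reduction: after five halving steps every lane holds the full result. A host-side NumPy sketch (illustrative only, not part of this source):

```python
import numpy as np

def butterfly_reduce(lanes, reduce_func):
    """Host-side sketch of the 5-step XOR shuffle reduction over 32 lanes."""
    vals = np.array(lanes, dtype=np.float64)
    assert vals.size == 32
    for offset in (16, 8, 4, 2, 1):
        # __shfl_xor_sync(0xffffffff, val, offset): lane i reads lane i ^ offset
        partner = vals[np.arange(32) ^ offset]
        vals = np.array([reduce_func(a, b) for a, b in zip(vals, partner)])
    return vals  # every lane now holds the full reduction

print(butterfly_reduce(range(32), max)[0])                 # 31.0
print(butterfly_reduce(range(32), lambda a, b: a + b)[0])  # 496.0
```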
template <typename DataType>
__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) {
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(1.0f / (1.0f + exp(-static_cast<float>(scores[i]))));
}
}
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
T (*reduce_func)(T, T), int lane_id) {
  // Some values are handled within the local thread:
  // thread 0 is responsible for elements 0, 32, 64, 96, ...
  // First reduce those values locally
volatile double val = lane_id < data_size && mask[lane_id]
? static_cast<double>(data_ptr[lane_id])
: static_cast<double>(0);
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
if (mask[i]) {
val = reduce_func(val, data_ptr[i]);
}
}
// Warp shuffle between threads
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 16));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 8));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 4));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 2));
val = reduce_func(val, __shfl_xor_sync(0xffffffff, val, 1));
__syncwarp();
return T(val);
}
template <typename DataType>
__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output,
int data_size, int lane_id) {
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
grad[i] = static_cast<double>(grad[i]) * static_cast<double>(fwd_output[i]) *
(1 - static_cast<double>(fwd_output[i]));
}
}
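The loop above is the pointwise sigmoid VJP: with saved forward output y and upstream gradient g, the input gradient is g * y * (1 - y). A NumPy restatement (illustrative sketch with made-up values):

```python
import numpy as np

# Sigmoid backward as computed per element above: grad_in = g * y * (1 - y),
# where y is the saved forward sigmoid output and g the incoming gradient.
x = np.linspace(-3.0, 3.0, 8)
y = 1.0 / (1.0 + np.exp(-x))   # forward output (as saved by the kernel)
g = np.ones_like(y)            # upstream gradient
grad_in = g * y * (1.0 - y)
```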
template <typename DataType>
__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output,
DataType *comp_buf, bool *mask, int data_size,
int lane_id) {
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) {
if (mask[i])
comp_buf[i] = static_cast<float>(grad[i]) * static_cast<float>(fwd_output[i]);
else
comp_buf[i] = 0.0f;
} else {
comp_buf[i] = static_cast<float>(grad[i]) * static_cast<float>(fwd_output[i]);
}
}
__syncwarp();
float sum_Output_x_Grad = warp_reduce_on_shmem(
/*data ptr = */ comp_buf,
/*data size = */ data_size,
/*reduce func = */ sum, lane_id);
// In-place update
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) {
if (mask[i])
grad[i] =
static_cast<float>(fwd_output[i]) * (static_cast<float>(grad[i]) - sum_Output_x_Grad);
else
grad[i] = 0.0f;
} else {
grad[i] =
static_cast<float>(fwd_output[i]) * (static_cast<float>(grad[i]) - sum_Output_x_Grad);
}
}
}
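The two passes above implement the standard softmax VJP, grad_in = y * (g - sum(y * g)), with masked-out entries zeroed. A NumPy reference (illustrative sketch, not the kernel itself):

```python
import numpy as np

def softmax_bwd(y, g, mask=None):
    """Reference for apply_softmax_bwd_on_float: y is the saved softmax
    output, g the upstream gradient; masked-out entries get zero grad."""
    if mask is None:
        mask = np.ones_like(y, dtype=bool)
    s = np.sum(np.where(mask, y * g, 0.0))   # warp-reduced sum of output * grad
    return np.where(mask, y * (g - s), 0.0)
```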
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
// 1. compute the max of value
float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
// 2. value -> exp_value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
}
__syncwarp();
// 3. compute the sum of exp_value
float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
// 4. update the softmax value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(scores[i]) / sum_val;
}
__syncwarp();
}
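Steps 1-4 above are the usual max-subtracted (numerically stable) softmax; a host-side mirror for reference (illustrative only):

```python
import numpy as np

def stable_softmax(scores):
    """Mirrors steps 1-4 of apply_softmax_on_float (host-side sketch)."""
    m = scores.max()            # 1. warp-reduced max
    e = np.exp(scores - m)      # 2. exponentiate the shifted values
    return e / e.sum()          # 3.-4. normalize by the warp-reduced sum
```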
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
T *topk_scores, int lane_id) {
  // Repeat topk times: find the max value and its index,
  // then mask it out and record the index in topk_indices.
  // After topk iterations, topk_indices holds the top-k indices.
for (int k = 0; k < topk; k++) {
// Find the max value and its index
volatile double val =
(lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
volatile int index = (lane_id < data_size) ? lane_id : 0;
    // Some values are handled within the local thread:
    // thread 0 is responsible for elements 0, 32, 64, 96, ...
    // First reduce those values locally
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
volatile double cur_val = scores[i];
if (cur_val > val) {
val = cur_val;
index = i;
}
}
// Warp shuffle between threads
for (int s = 16; s > 0; s /= 2) {
volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
}
}
if (lane_id == 0) {
topk_indices[k] = index;
topk_scores[k] = val;
      scores[index] =
          static_cast<double>(-1.0) - val;  // mask the selected expert: val -> -1 - val (self-inverse)
}
__syncwarp();
}
// Reset the scores to the original value
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
scores[topk_indices[i]] =
static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
}
}
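The masking trick deserves a note: writing val -> -1 - val pushes a selected entry below every remaining score (assuming scores lie in [0, 1], as softmax/sigmoid outputs do; a general-purpose version would need a different sentinel), and because the transform is its own inverse, the final loop restores the buffer in place. A host-side sketch under that assumption:

```python
def naive_topk_and_mask(scores, topk):
    """Host-side sketch of the kernel's mask-and-restore top-k loop.

    Assumes scores lie in [0, 1] (softmax/sigmoid outputs), so the
    sentinel -1 - val is strictly below every unselected score.
    """
    indices, values = [], []
    for _ in range(topk):
        i = max(range(len(scores)), key=lambda j: scores[j])
        indices.append(i)
        values.append(scores[i])
        scores[i] = -1.0 - scores[i]   # mask: pushes the entry below 0
    for i in indices:
        scores[i] = -1.0 - scores[i]   # involution: restores the original
    return indices, values

print(naive_topk_and_mask([0.1, 0.7, 0.2], 2))  # ([1, 2], [0.7, 0.2])
```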
// Currently TE only supports float32/bf16/fp16; float64 probs may be supported in the future
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt64: { \
using type = int64_t; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
} // namespace transformer_engine
#endif
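The two dispatch macros above bound the dtypes the router kernels accept: probabilities must be fp32/fp16/bf16 and index tensors int32/int64. A small Python-side guard mirroring the probs switch (a sketch; the actual bindings raise NVTE_ERROR from C++):

```python
import torch

_SUPPORTED_ROUTER_DTYPES = {torch.float32, torch.float16, torch.bfloat16}

def check_router_dtype(logits: torch.Tensor) -> None:
    # Mirrors TE_ROUTER_PROBS_TYPE_SWITCH_ALL: any other dtype is rejected.
    if logits.dtype not in _SUPPORTED_ROUTER_DTYPES:
        raise ValueError(f"Unsupported router dtype: {logits.dtype}")
```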
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_H_
#define TRANSFORMER_ENGINE_FUSED_ROUTER_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported.
*
* \param[in] logits Logits from the gating GEMM.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] use_pre_softmax Whether to use softmax before topk.
* \param[in] num_groups Number of groups in grouped topk.
* \param[in] group_topk Grouped topk value.
* \param[in] scaling_factor Scaling factor.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
 * \param[in] expert_bias Expert bias. (Only used in the sigmoid case.)
* \param[out] probs Output tensor for probabilities.
* \param[out] routing_map Output tensor for routing map.
* \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_topk_with_score_function_forward(
const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax,
int num_groups, int group_topk, float scaling_factor, int score_function,
const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map,
NVTETensor intermediate_output, cudaStream_t stream);
/*! \brief Backward pass for fused topk + softmax/sigmoid.
*
* \param[in] routing_map Routing map.
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] grad_probs Gradient of probs.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] use_pre_softmax Whether to use softmax before topk.
* \param[in] scaling_factor Scaling factor.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] grad_logits Gradient of logits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map,
const NVTETensor intermediate_output,
const NVTETensor grad_probs, int num_tokens,
int num_experts, int topk, int use_pre_softmax,
float scaling_factor, int score_function,
NVTETensor grad_logits, cudaStream_t stream);
/*! \brief Forward pass for computing scores/routing map for auxiliary loss.
*
* \param[in] logits Logits from the gating GEMM.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] scores Output tensor for scores.
 * \param[out] routing_map Output tensor for routing map.
 * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor scores, const NVTETensor routing_map,
const NVTETensor intermediate_output,
cudaStream_t stream);
/*! \brief Backward pass for computing scores/routing map for auxiliary loss.
*
* \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output)
* \param[in] grad_scores Gradient of scores.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] score_function Score function, 0: sigmoid, 1: softmax.
* \param[out] grad_logits Gradient of logits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output,
const NVTETensor grad_scores, int num_tokens,
int num_experts, int topk, int score_function,
NVTETensor grad_logits, cudaStream_t stream);
/*! \brief Forward pass for auxiliary loss.
*
* \param[in] probs Probabilities from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
 * \param[in] total_num_tokens Total number of tokens; used in seq/global aux loss.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] topk Topk value.
* \param[in] coeff Coefficient.
* \param[out] aux_loss Output GPU scalar for auxiliary loss.
 * \param[out] Const_buf Output GPU scalar; constant buffer saved for the backward pass.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts,
int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream);
/*! \brief Backward pass for auxiliary loss.
*
* \param[in] Const_buf Constant buffer from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
 * \param[in] num_tokens Number of tokens (rows of grad_probs).
* \param[in] num_experts Number of experts.
* \param[in] grad_aux_loss Gradient of auxiliary loss.
* \param[out] grad_probs Gradient of probs.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_tokens,
int num_experts, NVTETensor grad_aux_loss,
NVTETensor grad_probs, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif  // TRANSFORMER_ENGINE_FUSED_ROUTER_H_
...@@ -29,6 +29,7 @@
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
...
...@@ -13,6 +13,37 @@
namespace transformer_engine::pytorch {
/***************************************************************************************************
* Router fusion
**************************************************************************************************/
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
c10::optional<int> group_topk, c10::optional<float> scaling_factor, std::string score_function,
c10::optional<at::Tensor> expert_bias);
at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
at::Tensor routing_map,
at::Tensor intermediate_output, at::Tensor grad_probs,
int topk, bool use_pre_softmax,
c10::optional<float> scaling_factor,
std::string score_function);
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
at::Tensor logits, int topk, std::string score_function);
at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
                                            at::Tensor intermediate_output, at::Tensor grad_scores,
                                            int topk, std::string score_function);
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
at::Tensor tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff);
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
int num_tokens, int num_experts, at::Tensor grad_aux_loss);
/***************************************************************************************************
 * Permutation
 **************************************************************************************************/
...
...@@ -258,6 +258,32 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
// fused router
m.def("fused_topk_with_score_function_fwd",
&transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"),
py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"),
py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"),
"Fused topk softmax fwd");
m.def("fused_topk_with_score_function_bwd",
&transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"),
py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"),
py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"),
py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd");
m.def("fused_score_for_moe_aux_loss_fwd",
&transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"),
py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd");
m.def("fused_score_for_moe_aux_loss_bwd",
&transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"),
py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"),
py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
py::arg("num_tokens"), py::arg("num_experts"), py::arg("topk"), py::arg("coeff"),
"Fused aux loss fwd");
m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_tokens"),
py::arg("num_experts"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
  // Misc
  m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
        "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
namespace transformer_engine::pytorch {
static std::map<std::string, int> score_function_map = {{"sigmoid", 0}, {"softmax", 1}};
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
c10::optional<int> group_topk, c10::optional<float> scaling_factor, std::string score_function,
c10::optional<at::Tensor> expert_bias) {
int num_tokens = logits.size(0);
int num_experts = logits.size(1);
// Check if the input is valid
TORCH_CHECK(num_tokens > 0 && num_experts > 0,
"num_tokens and num_experts must be greater than 0");
  // Expert bias only applies to the sigmoid case
if (expert_bias.has_value()) {
TORCH_CHECK(score_function == "sigmoid",
"score_function must be sigmoid when expert_bias is not None");
}
// Check if the score function is valid
TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid",
"score_function must be softmax or sigmoid for router fusion");
if (score_function == "sigmoid") {
    use_pre_softmax = false;  // Pre-softmax only applies to the softmax case
}
// Reformat the input to make it compatible with the kernel
int group_topk_value = group_topk.has_value() ? group_topk.value() : -1;
int num_groups_value = num_groups.has_value() ? num_groups.value() : -1;
float scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
// Construct the output tensor
at::Tensor probs =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
at::Tensor routing_map =
at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA));
// Intermediate output is used to store the output of the softmax/sigmoid function
at::Tensor intermediate_output =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
auto logits_cu = makeTransformerEngineTensor(logits);
auto probs_cu = makeTransformerEngineTensor(probs);
auto routing_map_cu = makeTransformerEngineTensor(routing_map);
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
auto expert_bias_cu = TensorWrapper(); // empty expert_bias_cu tensor
if (expert_bias.has_value()) {
expert_bias_cu = makeTransformerEngineTensor(expert_bias.value());
}
nvte_fused_topk_with_score_function_forward(
logits_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, num_groups_value,
group_topk_value, scaling_factor_value, score_function_map[score_function],
expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(),
at::cuda::getCurrentCUDAStream());
return std::make_tuple(probs, routing_map, intermediate_output);
}
at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
at::Tensor routing_map,
at::Tensor intermediate_output, at::Tensor grad_probs,
int topk, bool use_pre_softmax,
c10::optional<float> scaling_factor,
std::string score_function) {
// Get the value of the parameters
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
{num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA));
auto routing_map_cu = makeTransformerEngineTensor(routing_map);
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
auto grad_logits_cu = makeTransformerEngineTensor(grad_logits);
nvte_fused_topk_with_score_function_backward(
routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens,
num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value,
grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());
return grad_logits;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
at::Tensor logits, int topk, std::string score_function) {
int num_tokens = logits.size(0);
int num_experts = logits.size(1);
// Check if the input is valid
TORCH_CHECK(num_tokens > 0 && num_experts > 0,
"num_tokens and num_experts must be greater than 0");
TORCH_CHECK(topk > 0, "topk must be greater than 0");
// Check if the score function is valid
TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid",
"score_function must be softmax or sigmoid for router fusion");
int score_function_value = score_function_map[score_function];
// Construct the output tensor
at::Tensor scores =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
at::Tensor routing_map =
at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA));
at::Tensor intermediate_output =
at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA));
auto logits_cu = makeTransformerEngineTensor(logits);
auto scores_cu = makeTransformerEngineTensor(scores);
auto routing_map_cu = makeTransformerEngineTensor(routing_map);
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
nvte_fused_score_for_moe_aux_loss_forward(
logits_cu.data(), num_tokens, num_experts, topk, score_function_value, scores_cu.data(),
routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream());
return std::make_tuple(scores, routing_map, intermediate_output);
}
at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
at::Tensor intermediate_output, at::Tensor grad_scores,
int topk, std::string score_function) {
// Get the value of the parameters
int score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
{num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA));
auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output);
auto grad_scores_cu = makeTransformerEngineTensor(grad_scores);
auto grad_logits_cu = makeTransformerEngineTensor(grad_logits);
nvte_fused_score_for_moe_aux_loss_backward(
intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk,
score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());
return grad_logits;
}
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
at::Tensor tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff) {
TORCH_CHECK(topk > 0, "topk must be greater than 0");
TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0");
TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0");
// Create the output tensor
at::Tensor aux_loss = at::empty({}, at::dtype(probs.scalar_type()).device(at::kCUDA));
at::Tensor Const_buf = at::empty({}, at::dtype(at::kFloat).device(at::kCUDA));
auto probs_cu = makeTransformerEngineTensor(probs);
auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
auto aux_loss_cu = makeTransformerEngineTensor(aux_loss);
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
num_tokens, num_experts, topk, coeff, aux_loss_cu.data(),
Const_buf_cu.data(), at::cuda::getCurrentCUDAStream());
return std::make_tuple(aux_loss, Const_buf);
}
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
int num_tokens, int num_experts, at::Tensor grad_aux_loss) {
// Create the output tensor
at::Tensor grad_probs = at::empty({num_tokens, num_experts},
at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
auto grad_aux_loss_cu = makeTransformerEngineTensor(grad_aux_loss);
auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
  // Launch the backward kernel
nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_tokens,
num_experts, grad_aux_loss_cu.data(), grad_probs_cu.data(),
at::cuda::getCurrentCUDAStream());
return grad_probs;
}
} // namespace transformer_engine::pytorch
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Fused functions used in the MoE router
"""
import torch
import transformer_engine_torch as tex
class FusedTopkScoreFunction(torch.autograd.Function):
"""
Fused Topk with Score Function router.
    Currently only softmax and sigmoid are supported.
"""
@staticmethod
def forward(
ctx,
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool,
num_groups: int,
group_topk: int,
scaling_factor: float,
score_function: str,
expert_bias: torch.Tensor,
):
# pylint: disable=missing-function-docstring
probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
logits,
topk,
use_pre_softmax,
num_groups,
group_topk,
scaling_factor,
score_function,
expert_bias,
)
ctx.save_for_backward(routing_map, intermediate_output)
ctx.num_tokens = logits.size(0)
ctx.num_experts = logits.size(1)
ctx.use_pre_softmax = use_pre_softmax
ctx.topk = topk
ctx.scaling_factor = scaling_factor
ctx.score_function = score_function
return probs, routing_map
@staticmethod
def backward(ctx, grad_probs, _):
# pylint: disable=missing-function-docstring
routing_map, intermediate_output = ctx.saved_tensors
grad_logits = tex.fused_topk_with_score_function_bwd(
ctx.num_tokens,
ctx.num_experts,
routing_map,
intermediate_output,
grad_probs.contiguous(),
ctx.topk,
ctx.use_pre_softmax,
ctx.scaling_factor,
ctx.score_function,
)
return grad_logits, None, None, None, None, None, None, None
def fused_topk_with_score_function(
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool,
num_groups: int,
group_topk: int,
scaling_factor: float,
score_function: str,
expert_bias: torch.Tensor,
):
"""
Fused topk with score function router.
Parameters
----------
logits: torch.Tensor
topk: int
use_pre_softmax: bool
        if enabled, the computation order is softmax -> topk
num_groups: int
used in the group topk
group_topk: int
used in the group topk
scaling_factor: float
score_function: str
        currently only softmax and sigmoid are supported
expert_bias: torch.Tensor
        optional bias added to the scores; only valid with the sigmoid score function
Returns
-------
probs: torch.Tensor
routing_map: torch.Tensor
"""
if logits.dtype == torch.float64:
raise ValueError("Current TE does not support float64 router type")
return FusedTopkScoreFunction.apply(
logits,
topk,
use_pre_softmax,
num_groups,
group_topk,
scaling_factor,
score_function,
expert_bias,
)
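A minimal usage sketch of the wrapper above (shapes and hyperparameters are made up for the example; requires a CUDA device):

```python
import torch
from transformer_engine.pytorch.router import fused_topk_with_score_function

logits = torch.randn(16, 8, device="cuda", requires_grad=True)  # [tokens, experts]
probs, routing_map = fused_topk_with_score_function(
    logits,
    topk=2,
    use_pre_softmax=True,
    num_groups=None,          # plain (non-grouped) top-k
    group_topk=None,
    scaling_factor=1.0,
    score_function="softmax",
    expert_bias=None,
)
probs.sum().backward()        # gradients flow to logits via the fused bwd
```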
class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function):
"""
Fused compute scores for MoE aux loss.
"""
@staticmethod
def forward(
ctx,
logits: torch.Tensor,
topk: int,
score_function: str,
):
# pylint: disable=missing-function-docstring
scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd(
logits=logits,
topk=topk,
score_function=score_function,
)
ctx.save_for_backward(intermediate_output)
ctx.topk = topk
ctx.score_function = score_function
ctx.num_tokens = logits.size(0)
ctx.num_experts = logits.size(1)
return routing_map, scores
@staticmethod
def backward(ctx, _, grad_scores):
# pylint: disable=missing-function-docstring
intermediate_output = ctx.saved_tensors[0]
grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
num_tokens=ctx.num_tokens,
num_experts=ctx.num_experts,
intermediate_output=intermediate_output,
grad_scores=grad_scores.contiguous(),
topk=ctx.topk,
score_function=ctx.score_function,
)
return grad_logits, None, None
def fused_compute_score_for_moe_aux_loss(
logits: torch.Tensor,
topk: int,
score_function: str,
):
"""
    Fused computation of scores for the MoE aux loss; a subset of fused_topk_with_score_function.
Parameters
----------
logits: torch.Tensor
topk: int
score_function: str
        currently only softmax and sigmoid are supported
Returns
-------
routing_map: torch.Tensor
scores: torch.Tensor
"""
return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)
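Illustrative usage (made-up shapes); note the return order is (routing_map, scores):

```python
import torch
from transformer_engine.pytorch.router import fused_compute_score_for_moe_aux_loss

logits = torch.randn(16, 8, device="cuda", requires_grad=True)
routing_map, scores = fused_compute_score_for_moe_aux_loss(
    logits, topk=2, score_function="softmax"
)
```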
class FusedAuxLoss(torch.autograd.Function):
"""
Fused MoE aux loss.
"""
@staticmethod
def forward(
ctx,
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
num_experts: int,
topk: int,
coeff: float,
):
# pylint: disable=missing-function-docstring
num_tokens = probs.size(0)
aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd(
probs=probs,
tokens_per_expert=tokens_per_expert,
total_num_tokens=total_num_tokens,
num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
coeff=coeff,
)
ctx.save_for_backward(Const_buf, tokens_per_expert)
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
return aux_loss
@staticmethod
def backward(ctx, grad_aux_loss):
# pylint: disable=missing-function-docstring
Const_buf, tokens_per_expert = ctx.saved_tensors
grad_probs = tex.fused_moe_aux_loss_bwd(
Const_buf=Const_buf,
tokens_per_expert=tokens_per_expert,
num_tokens=ctx.num_tokens,
num_experts=ctx.num_experts,
grad_aux_loss=grad_aux_loss,
)
return grad_probs, None, None, None, None, None
def fused_moe_aux_loss(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
num_experts: int,
topk: int,
coeff: float,
):
"""
Fused MoE aux loss.
Parameters
----------
probs: torch.Tensor
tokens_per_expert: torch.Tensor
the number of tokens per expert
total_num_tokens: int
        the total number of tokens involved in the aux loss calculation
num_experts: int
topk: int
coeff: float
the coefficient of the aux loss
Returns
-------
    aux_loss: torch.Tensor
        scalar tensor containing the aux loss
"""
return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)
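End-to-end sketch of the aux-loss path (illustrative values; `tokens_per_expert` is assumed int32/int64, per the index-type switch in the kernel headers):

```python
import torch
from transformer_engine.pytorch.router import (
    fused_compute_score_for_moe_aux_loss,
    fused_moe_aux_loss,
)

logits = torch.randn(16, 8, device="cuda", requires_grad=True)
routing_map, scores = fused_compute_score_for_moe_aux_loss(
    logits, topk=2, score_function="softmax"
)
tokens_per_expert = routing_map.sum(dim=0).to(torch.int32)  # [num_experts]
aux_loss = fused_moe_aux_loss(
    probs=scores,
    tokens_per_expert=tokens_per_expert,
    total_num_tokens=logits.size(0),
    num_experts=logits.size(1),
    topk=2,
    coeff=1e-2,
)
aux_loss.backward()           # grads reach logits through both fused kernels
```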