Unverified Commit 397c4be6 authored by Autumn1998's avatar Autumn1998 Committed by GitHub
Browse files

[PyTorch] Fix bugs in router fusion (#1944)



* fix underterminsic problem in CI
Signed-off-by: default avatartongliu <tongliu@nvidia.com>

* fix bug on mbs>1
Signed-off-by: default avatartongliu <tongliu@nvidia.com>

* fix bug on sm dispatcher
Signed-off-by: default avatartongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CI initial values
Signed-off-by: default avatartongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent dc97cc9e
......@@ -148,14 +148,17 @@ def run_comparison(
):
# Set some parameters
if score_function == "sigmoid":
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
logits = logits.unsqueeze(0).repeat(num_tokens, 1)
# Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
expert_bias = torch.arange(num_experts, device="cuda") * 0.1
expert_bias = torch.flip(expert_bias, dims=[0])
expert_bias.requires_grad = True
else:
expert_bias = None
......@@ -210,7 +213,7 @@ def run_comparison(
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
......@@ -241,7 +244,7 @@ def test_topk_sigmoid(
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
......@@ -272,12 +275,19 @@ def test_topk_softmax(
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
logits_clone = deepcopy(logits)
......@@ -307,11 +317,15 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
probs = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
# Construct the special probs to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
probs = probs.view(num_tokens, num_experts)
probs.requires_grad = True
tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
......@@ -375,6 +389,6 @@ if __name__ == "__main__":
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
......@@ -23,8 +23,8 @@ using CompType = double;
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff,
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf) {
#if __CUDA_ARCH__ >= 900
// Using cooperative_groups to manage the cluster
......@@ -43,7 +43,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
extern __shared__ float shmem_aux_loss[];
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
......@@ -54,11 +54,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* 2. reduce on the cluster
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_tokens; j += warp_num) {
tmp += CompType(probs[j * num_experts + i]);
for (int j = warp_id; j < num_rows; j += warp_num) {
tmp += CompType(probs[j * num_cols + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
......@@ -68,7 +68,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
for (int i = 1; i < block_num; i++) {
// Map the shared memory of the block i to the current block
CompType* dst_smem = reinterpret_cast<CompType*>(cluster.map_shared_rank(shmem_aux_loss, i));
for (int j = threadIdx.x; j < num_experts; j += blockDim.x) {
for (int j = threadIdx.x; j < num_cols; j += blockDim.x) {
atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]);
}
}
......@@ -80,7 +80,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* In-place update on shmem
*/
if (block_id == 0) {
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
......@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
__syncwarp();
if (lane_id == 0) {
......@@ -113,7 +113,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
......@@ -122,11 +122,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce the probs to the aggregated_probs_per_expert
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_tokens; j += warp_num) {
tmp += CompType(probs[j * num_experts + i]);
for (int j = warp_id; j < num_rows; j += warp_num) {
tmp += CompType(probs[j * num_cols + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
......@@ -136,7 +136,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
......@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
__syncwarp();
if (lane_id == 0) {
......@@ -164,16 +164,17 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf,
cudaStream_t stream) {
if (cuda::sm_arch(cuda::current_device()) >= 900) {
if (cuda::sm_arch(cuda::current_device()) >= 90) {
cudaLaunchConfig_t config = {0};
int cluster_size = 8;
config.gridDim = cluster_size;
config.blockDim = 1024;
config.dynamicSmemBytes = sizeof(CompType) * num_experts;
config.dynamicSmemBytes = sizeof(CompType) * num_cols;
config.stream = stream;
// Update the max cluster size based on the device
cudaOccupancyMaxPotentialClusterSize(
......@@ -189,19 +190,19 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
config.attrs = attribute;
cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
tokens_per_expert, total_num_tokens, num_tokens, num_experts, topk, coeff,
aux_loss, Const_buf);
tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
coeff, aux_loss, Const_buf);
} else {
size_t smem_size = sizeof(CompType) * num_experts;
size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_tokens,
num_experts, topk, coeff, aux_loss, Const_buf);
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
}
}
void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts, int topk,
float coeff, Tensor& aux_loss, Tensor& Const_buf,
int total_num_tokens, int num_experts, int num_rows, int num_cols,
int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf,
cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
probs.data.dtype, DataType,
......@@ -210,45 +211,46 @@ void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_ex
fused_moe_aux_loss_forward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<DataType*>(probs.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
num_tokens, num_experts, topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
num_experts, num_rows, num_cols, topk, coeff,
reinterpret_cast<DataType*>(aux_loss.data.dptr),
reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
}
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
const IndexType* tokens_per_expert,
int num_tokens, int num_experts,
DataType* grad_aux_loss, DataType* grad_probs) {
const IndexType* tokens_per_expert, int num_rows,
int num_cols, DataType* grad_aux_loss,
DataType* grad_probs) {
int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
float C_coeff = Const_buf[0];
IndexType tokens_per_expert_i = tokens_per_expert[i];
double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
// Loop: for all rows
for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
}
}
}
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
const IndexType* tokens_per_expert, int num_tokens,
int num_experts, DataType* grad_aux_loss,
const IndexType* tokens_per_expert, int num_rows,
int num_cols, DataType* grad_aux_loss,
DataType* grad_probs, cudaStream_t stream) {
// Meta data for the kernel
int block_size = 256;
int grid_size = (num_tokens + block_size - 1) / block_size;
int grid_size = (num_rows + block_size - 1) / block_size;
fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs);
Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
}
void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
int num_tokens, int num_experts, Tensor& grad_aux_loss,
int num_rows, int num_cols, Tensor& grad_aux_loss,
Tensor& grad_probs, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_aux_loss.data.dtype, DataType,
......@@ -256,7 +258,7 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p
tokens_per_expert.data.dtype, IndexType,
fused_moe_aux_loss_backward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<float*>(Const_buf.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_tokens, num_experts,
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_rows, num_cols,
reinterpret_cast<DataType*>(grad_aux_loss.data.dptr),
reinterpret_cast<DataType*>(grad_probs.data.dptr), stream);););
}
......@@ -264,25 +266,25 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p
} // namespace transformer_engine
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts,
int topk, float coeff, NVTETensor aux_loss,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
using namespace transformer_engine;
fused_moe_aux_loss_forward(
*convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss),
num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss),
*convertNVTETensorCheck(Const_buf), stream);
}
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_tokens,
int num_experts, NVTETensor grad_aux_loss,
NVTETensor grad_probs, cudaStream_t stream) {
const NVTETensor tokens_per_expert, int num_rows,
int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
using namespace transformer_engine;
fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
*convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts,
*convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols,
*convertNVTETensorCheck(grad_aux_loss),
*convertNVTETensorCheck(grad_probs), stream);
}
......@@ -96,8 +96,9 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou
* \param[in] probs Probabilities from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
* \param[in] total_num_tokens Number of total tokens. Will be used in seq/global aux loss.
* \param[in] num_tokens Number of tokens.
* \param[in] num_experts Number of experts.
* \param[in] num_rows Number of rows of probs.
* \param[in] num_cols Number of columns of probs.
* \param[in] topk Topk value.
* \param[in] coeff Coefficient.
* \param[out] aux_loss Output GPU scalar for auxiliary loss.
......@@ -105,24 +106,24 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts,
int topk, float coeff, NVTETensor aux_loss,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream);
/*! \brief Backward pass for auxiliary loss.
*
* \param[in] Const_buf Constant buffer from the forward pass.
* \param[in] tokens_per_expert Number of tokens per expert.
* \param[in] num_tokens Number of total tokens.
* \param[in] num_experts Number of experts.
* \param[in] num_rows Number of rows of probs.
* \param[in] num_cols Number of columns of probs.
* \param[in] grad_aux_loss Gradient of auxiliary loss.
* \param[out] grad_probs Gradient of probs.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_tokens,
int num_experts, NVTETensor grad_aux_loss,
NVTETensor grad_probs, cudaStream_t stream);
const NVTETensor tokens_per_expert, int num_rows,
int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -38,11 +38,12 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
at::Tensor tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff);
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk,
float coeff);
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
int num_tokens, int num_experts, at::Tensor grad_aux_loss);
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
int num_cols, at::Tensor grad_aux_loss);
/***************************************************************************************************
* Permutation
......
......@@ -278,11 +278,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
py::arg("num_tokens"), py::arg("num_experts"), py::arg("topk"), py::arg("coeff"),
"Fused aux loss fwd");
py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"),
py::arg("coeff"), "Fused aux loss fwd");
m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_tokens"),
py::arg("num_experts"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
// Misc
m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
......
......@@ -145,8 +145,9 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
at::Tensor tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff) {
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk,
float coeff) {
TORCH_CHECK(topk > 0, "topk must be greater than 0");
TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0");
TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0");
......@@ -161,17 +162,17 @@ std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
num_tokens, num_experts, topk, coeff, aux_loss_cu.data(),
num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(),
Const_buf_cu.data(), at::cuda::getCurrentCUDAStream());
return std::make_tuple(aux_loss, Const_buf);
}
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
int num_tokens, int num_experts, at::Tensor grad_aux_loss) {
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
int num_cols, at::Tensor grad_aux_loss) {
// Create the output tensor
at::Tensor grad_probs = at::empty({num_tokens, num_experts},
at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
at::Tensor grad_probs =
at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);
......@@ -179,8 +180,8 @@ at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_ex
auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
// Meta data for the kernel
nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_tokens,
num_experts, grad_aux_loss_cu.data(), grad_probs_cu.data(),
nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows,
num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(),
at::cuda::getCurrentCUDAStream());
return grad_probs;
......
......@@ -27,6 +27,12 @@ class FusedTopkScoreFunction(torch.autograd.Function):
expert_bias: torch.Tensor,
):
# pylint: disable=missing-function-docstring
# Save the shape of the logits
tensor_shape = logits.shape
logits = logits.view(-1, tensor_shape[-1])
# Get the metadata of the viewed logits
num_tokens = logits.size(0)
num_experts = logits.size(1)
probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
logits,
topk,
......@@ -37,9 +43,11 @@ class FusedTopkScoreFunction(torch.autograd.Function):
score_function,
expert_bias,
)
# Restore the shape
probs = probs.view(tensor_shape)
ctx.save_for_backward(routing_map, intermediate_output)
ctx.num_tokens = logits.size(0)
ctx.num_experts = logits.size(1)
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
ctx.use_pre_softmax = use_pre_softmax
ctx.topk = topk
ctx.scaling_factor = scaling_factor
......@@ -50,17 +58,23 @@ class FusedTopkScoreFunction(torch.autograd.Function):
def backward(ctx, grad_probs, _):
# pylint: disable=missing-function-docstring
routing_map, intermediate_output = ctx.saved_tensors
# Save the shape of the grad_probs
tensor_shape = grad_probs.shape
# Adjust the shape of the grad_probs to 2D shape
grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1])
grad_logits = tex.fused_topk_with_score_function_bwd(
ctx.num_tokens,
ctx.num_experts,
routing_map,
intermediate_output,
grad_probs.contiguous(),
grad_probs,
ctx.topk,
ctx.use_pre_softmax,
ctx.scaling_factor,
ctx.score_function,
)
# Restore the shape
grad_logits = grad_logits.view(tensor_shape)
return grad_logits, None, None, None, None, None, None, None
......@@ -124,6 +138,12 @@ class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function):
score_function: str,
):
# pylint: disable=missing-function-docstring
# Save the shape of the logits
tensor_shape = logits.shape
logits = logits.view(-1, tensor_shape[-1])
# Get the metadata of the viewed logits
num_tokens = logits.size(0)
num_experts = logits.size(1)
scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd(
logits=logits,
topk=topk,
......@@ -132,22 +152,28 @@ class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function):
ctx.save_for_backward(intermediate_output)
ctx.topk = topk
ctx.score_function = score_function
ctx.num_tokens = logits.size(0)
ctx.num_experts = logits.size(1)
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
return routing_map, scores
@staticmethod
def backward(ctx, _, grad_scores):
# pylint: disable=missing-function-docstring
intermediate_output = ctx.saved_tensors[0]
# Save the shape of the grad_scores
tensor_shape = grad_scores.shape
# Adjust the shape of the grad_scores to 2D shape
grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1])
grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
num_tokens=ctx.num_tokens,
num_experts=ctx.num_experts,
intermediate_output=intermediate_output,
grad_scores=grad_scores.contiguous(),
grad_scores=grad_scores,
topk=ctx.topk,
score_function=ctx.score_function,
)
# Restore the shape
grad_logits = grad_logits.view(tensor_shape)
return grad_logits, None, None
......@@ -189,19 +215,21 @@ class FusedAuxLoss(torch.autograd.Function):
coeff: float,
):
# pylint: disable=missing-function-docstring
num_tokens = probs.size(0)
num_rows = probs.size(0)
num_cols = probs.size(1)
aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd(
probs=probs,
tokens_per_expert=tokens_per_expert,
total_num_tokens=total_num_tokens,
num_tokens=num_tokens,
num_experts=num_experts,
num_rows=num_rows,
num_cols=num_cols,
topk=topk,
coeff=coeff,
)
ctx.save_for_backward(Const_buf, tokens_per_expert)
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
ctx.num_rows = num_rows
ctx.num_cols = num_cols
return aux_loss
@staticmethod
......@@ -211,8 +239,8 @@ class FusedAuxLoss(torch.autograd.Function):
grad_probs = tex.fused_moe_aux_loss_bwd(
Const_buf=Const_buf,
tokens_per_expert=tokens_per_expert,
num_tokens=ctx.num_tokens,
num_experts=ctx.num_experts,
num_rows=ctx.num_rows,
num_cols=ctx.num_cols,
grad_aux_loss=grad_aux_loss,
)
return grad_probs, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment