Unverified commit a3353744, authored by NVJiangShao and committed by GitHub

[MoE][Common/PyTorch] Add permutation (#936)



* Add permutation functions

* Add permutation ops

* Remove the dependency on cutlass

* Move permutation.py out of module dir

* Rewrite the unit test and enable skipping if FP8 is unavailable

* Rename exposed C++ API and reorder its parameters + take NVTETensor as inputs

* Use Float8Tensor for FP8 input

* Move dtype to ctx

---------
Signed-off-by: Jiang Shao <jiangs@nvidia.com>
Co-authored-by: Qi Zhang <qizhang@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 47caafb2
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from typing import Dict, List
from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
import transformer_engine_torch as tex
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def pytorch_permute(tokens, indices, num_out_tokens: int = None):
"""
Permute the tokens based on the indices. Tokens with the same index are grouped together.
The indices tensor has shape [num_tokens, top_k] and indicates which experts each token selected.
Args:
tokens: torch.Tensor
The input token tensor.
indices: torch.Tensor
The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
num_out_tokens: int, optional
The effective number of output tokens. When a capacity factor is enabled, this should equal the number of tokens that were not dropped.
By default, set to None, meaning no tokens are dropped.
Returns:
torch.Tensor:
The permuted tensor.
torch.Tensor:
The sorted_indices tensor corresponding to the permuted tensor.
"""
if indices.dim() == 1:
topk = 1
else:
topk = indices.size(1)
flatten_indices = indices.view(-1)
sorted_indices = torch.argsort(flatten_indices, stable=True)
num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0)
permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk)
return permuted_tokens, sorted_indices
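# Illustrative walk-through of pytorch_permute above (values chosen only for the example):
# with 2 tokens and indices [[1, 0], [0, 1]] (topk=2), the flattened indices are [1, 0, 0, 1],
# the stable argsort gives sorted_indices = [1, 2, 0, 3], and the permuted output rows are
# tokens[[0, 1, 0, 1]]: the two rows routed to expert 0 first, then the two routed to expert 1.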
def pytorch_unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor = None,
):
"""
Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their
corresponding probabilities.
Args:
permuted_tokens: torch.Tensor
The tensor of permuted tokens to be unpermuted.
sorted_indices: torch.Tensor
The tensor of sorted indices used to unpermute the tokens.
probs: torch.Tensor, optional
The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will
be merged with their respective probabilities.
Returns:
torch.Tensor:
The unpermuted tokens, optionally merged with probabilities.
"""
if probs is not None:
# Unpermute and merge the tokens with their probabilities
num_unpermuted_tokens = probs.numel()
topk = probs.size(1)
else:
# Unpermute the tokens without merge
num_unpermuted_tokens = sorted_indices.size(0)
topk = 1
unpermuted_tokens = torch.zeros(
[num_unpermuted_tokens, permuted_tokens.shape[-1]],
dtype=permuted_tokens.dtype,
device=permuted_tokens.device,
)
unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens)
unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
if probs is not None:
unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
unpermuted_tokens = unpermuted_tokens.sum(dim=1)
return unpermuted_tokens
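# Round-trip sketch for the example above: unpermuting with probs of shape [num_tokens, topk]
# computes out[t] = sum_k probs[t, k] * permuted_tokens[row_of(t, k)], i.e. each token's topk
# permuted copies are gathered back and merged with their routing probabilities; with
# probs=None the permuted rows are simply scattered back to their original token order.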
def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]:
"""Estimated tolerances for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if te_dtype == tex.DType.kFloat32:
return dict(rtol=1.0e-6, atol=1.0e-6)
if te_dtype == tex.DType.kFloat16:
return dict(rtol=3.0e-3, atol=1.0e-5)
if te_dtype == tex.DType.kBFloat16:
return dict(rtol=2.0e-2, atol=1.0e-5)
if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
return dict(rtol=2.0e-1, atol=1.0e-1)
raise ValueError(f"Unsupported dtype ({te_dtype})")
def _test_permutation(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_probs,
BENCHMARK=False,
):
if not with_probs and topK > 1:
pytest.skip("Only permutations with topK=1 and without probabilities are supported.")
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
if num_out_tokens is None:
num_out_tokens = num_tokens * topK
print(
f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_fwd_input = Float8Tensor.to_float8(
permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
)
permute_bwd_input = Float8Tensor.to_float8(
permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
)
unpermute_bwd_input = Float8Tensor.to_float8(
unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
)
pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
if num_tokens > 0:
indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)])
else:
indices = torch.empty((num_tokens, topK))
indices = indices.to(torch.int32).cuda()
probs = None
if with_probs:
probs = torch.rand(num_tokens, topK).cuda()
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs.requires_grad_(True)
###################################################################################################################################
#
# PyTorch Permutation
#
###################################################################################################################################
pytorch_permute_output, sorted_indices = pytorch_permute(
pytorch_permute_fwd_input, indices, num_out_tokens
)
pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)
pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
pytorch_unpermute_fwd_input.requires_grad_(True)
pytorch_unpermute_output = pytorch_unpermute(
pytorch_unpermute_fwd_input, sorted_indices, probs=probs
)
pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, te_dtype, indices, num_out_tokens
)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)
te_probs = None
if with_probs:
te_probs = probs.detach()
te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.from_float8(torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32)
te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32)
else:
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
return
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
def backward_wrapper(
act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False
):
# Set forward_input.grad to None to avoid grad accumulation.
if not accumulate_grad:
for i in forward_input:
i.grad = None
return act.backward(backward_input, retain_graph=retain_graph)
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
)
t2 = perf_test_cuda_kernel(
lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens)
)
print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_permute_output,
pytorch_permute_bwd_input,
forward_input=[pytorch_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
)
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
)
print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_unpermute_output,
pytorch_unpermute_bwd_input,
forward_input=(
[pytorch_unpermute_fwd_input, probs]
if with_probs
else [pytorch_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_unpermute_output,
te_unpermute_bwd_input,
forward_input=(
[te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def perf_test_cuda_kernel(cuda_kernel_fn):
if torch.cuda.is_available():
# create CUDA event
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# warmup
for _ in range(50):
cuda_kernel_fn()
start_event.record()
for _ in range(100):
cuda_kernel_fn()
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
return elapsed_time_ms / 100
else:
pytest.skip("CUDA is not available.")
# TE tensor dtypes
_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
if is_bf16_compatible():
_te_dtypes.append(tex.DType.kBFloat16)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
# Only run FP8 tests on devices that support FP8.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_topk1_no_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
):
topK = 1
num_out_tokens = None
with_probs = False
BENCHMARK = False
_test_permutation(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
def test_permutation_single_case():
print("GPU:", torch.cuda.get_device_name(0))
# te_dtype = tex.DType.kFloat32
# te_dtype = tex.DType.kFloat16
# te_dtype = tex.DType.kBFloat16
te_dtype = tex.DType.kFloat8E5M2
# te_dtype = tex.DType.kFloat8E4M3
num_tokens = 10
num_expert = 4
hidden_size = 16
topK = 2
num_out_tokens = num_tokens * topK - 1
with_probs = True
Benchmark = True
_test_permutation(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=Benchmark,
)
if __name__ == "__main__":
test_permutation_single_case()
@@ -62,6 +62,7 @@ list(APPEND transformer_engine_SOURCES
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
permutation/permutation.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
......
@@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Unable to find suitable cuBLAS GEMM algorithm");
NVTE_CHECK_CUBLAS(status);
if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_
#define TRANSFORMER_ENGINE_PERMUTATION_H_
#include "transformer_engine.h"
void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr);
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream = nullptr);
#endif // TRANSFORMER_ENGINE_PERMUTATION_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/permutation.h>
#include "../common.h"
static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map,
const int num_rows, const int topK,
const int num_out_tokens) {
// Each block corresponds to one source token
// row_id_map[topK][num_rows]
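// Illustrative example: with num_rows = 2, topK = 2 and expert indices [[1, 0], [0, 1]],
// sorted_row_id = [1, 2, 0, 3] and this kernel writes row_id_map = [2, 1, 0, 3]
// (layout [topK][num_rows]; each entry is the destination row, or -1 for dropped tokens).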
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const int idx = bid * blockDim.x + tid;
if (idx >= num_rows * topK) return;
int source_row = sorted_row_id[idx];
int source_token_id = source_row / topK;
int source_topK_id = source_row % topK;
if (idx >= num_out_tokens) {
// Set the indices of dropped tokens to -1
row_id_map[source_topK_id * num_rows + source_token_id] = -1;
} else {
// Create a row id map for subsequent unpermute operation
row_id_map[source_topK_id * num_rows + source_token_id] = idx;
}
}
template <typename T, typename TCompute, bool hasProb>
__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map,
const float *prob, const int num_rows, const int topK,
const int num_cols) {
extern __shared__ int8_t s_mem[];
TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);
// Each block corresponds to one dest token
const int source_token = blockIdx.x;
const int tid = threadIdx.x;
if (hasProb) {
for (int i = tid; i < topK; i += blockDim.x * blockDim.y) {
// Load all the topK probs related to the source row into smem
s_prob[i] = TCompute(prob[source_token * topK + i]);
}
__syncthreads();
}
// Register buffers for vector type (float4) memory access
float4 frag_load_store;
T *frag_load_store_ptr = reinterpret_cast<T *>(&frag_load_store);
// Number of elements in frag_load_store
static constexpr int kElementsPerAccess = 16 / sizeof(T);
// Traverse along the hidden dimension
for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
TCompute frag_elem[kElementsPerAccess];
TCompute frag_sum[kElementsPerAccess];
int source_row = row_id_map[source_token];
// source_row == -1 represents a dropped token
if (source_row != -1) {
const T *source_row_ptr = input + source_row * num_cols;
frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));
for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] = TCompute(frag_load_store_ptr[e]);
}
if (hasProb) {
for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] = frag_sum[e] * s_prob[0];
}
}
} else {
for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] = TCompute(0.0f);
}
}
for (int k = 1; k < topK; k++) {
source_row = row_id_map[k * num_rows + source_token];
if (source_row == -1) continue;
const T *source_row_ptr = input + source_row * num_cols;
frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));
for (int e = 0; e < kElementsPerAccess; e++) {
frag_elem[e] = TCompute(frag_load_store_ptr[e]);
}
if (hasProb) {
for (int e = 0; e < kElementsPerAccess; e++) {
frag_elem[e] = frag_elem[e] * s_prob[k];
}
}
for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] = frag_sum[e] + frag_elem[e];
}
}
T *dest_row_ptr = unpermuted_output + source_token * num_cols;
for (int e = 0; e < kElementsPerAccess; e++) {
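// For FP8 data without probs, the topK partial sums are averaged here; the permute
// backward wrapper in permutation.py multiplies the FP8 scale_inv by topK to recover the sum.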
if constexpr ((std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>) &&
(!hasProb)) {
frag_sum[e] = frag_sum[e] / TCompute(topK);
}
frag_load_store_ptr[e] = T(frag_sum[e]);
}
*reinterpret_cast<float4 *>(dest_row_ptr + i) = frag_load_store;
}
}
template <typename T, typename TCompute, int topKTile, bool hasProb>
__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad,
const float *prob, float *prob_grad, const int *row_id_map,
const int num_rows, const int topK, const int num_cols) {
extern __shared__ int8_t s_mem[];
TCompute *s_prob = reinterpret_cast<TCompute *>(s_mem);
// Each block corresponds to one source token
const int source_token = blockIdx.x;
const int tid = threadIdx.x;
if (hasProb) {
for (int i = tid; i < topK; i += blockDim.x) {
// Load all the topK probs related to the source row into smem
s_prob[i] = TCompute(prob[source_token * topK + i]);
}
__syncthreads();
}
// Accumulators for the calculation of prob_grad
float accum[topKTile] = {0.0f};
// Register buffers for vector type (float4) memory access
float4 frag_load_store;
T *frag_load_store_ptr = reinterpret_cast<T *>(&frag_load_store);
// Number of elements in frag_load_store
static constexpr int kElementsPerAccess = 16 / sizeof(T);
// The starting address of each source row
const T *source_row_ptr = input_bwd + source_token * num_cols;
// Traverse along the hidden dimension
for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) {
TCompute frag_src[kElementsPerAccess];
frag_load_store = __ldlu(reinterpret_cast<const float4 *>(source_row_ptr + i));
for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]);
int index = source_token;
// Process each row in the corresponding topK rows
for (int k = 0; k < topKTile; k++) {
if (k == topK) break;
int dest_row = row_id_map[index];
index += num_rows;
if (dest_row != -1) {
if (hasProb) {
// Calculate act_grad in unpermute bwd
for (int e = 0; e < kElementsPerAccess; e++)
frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]);
} else {
// permute fwd
for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]);
}
T *dest_row_ptr = act_grad + dest_row * num_cols;
*reinterpret_cast<float4 *>(dest_row_ptr + i) = frag_load_store;
if (hasProb) {
// Inner product calculation for prob_grad in unpermute bwd
const T *input_fwd_ptr = input_fwd + dest_row * num_cols;
frag_load_store = __ldlu(reinterpret_cast<const float4 *>(input_fwd_ptr + i));
TCompute frag_input_fwd[kElementsPerAccess];
for (int e = 0; e < kElementsPerAccess; e++)
frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]);
for (int e = 0; e < kElementsPerAccess; e++) {
accum[k] += static_cast<float>(frag_src[e] * frag_input_fwd[e]);
}
}
}
}
}
if (hasProb) {
for (int k = 0; k < topKTile; k++) {
if (k == topK) break;
// Warp-level reduction
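// XOR-shuffle butterfly: after 5 steps every lane of the 32-lane warp holds the sum over
// the whole block (this hasProb path is launched with 32 threads per block, i.e. one warp).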
for (int mask = 16; mask > 0; mask /= 2) {
accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32);
}
}
if (tid == 0) {
for (int k = 0; k < topKTile; k++) {
if (k == topK) break;
prob_grad[source_token * topK + k] = accum[k];
}
}
}
}
template <typename T>
void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map,
const float *prob, float *prob_grad, const T *input_fwd,
const int num_rows, const int topK, const int num_cols,
const int num_out_tokens, cudaStream_t stream) {
using TCompute = typename std::conditional<(std::is_same<T, __nv_fp8_e5m2>::value ||
std::is_same<T, __nv_fp8_e4m3>::value),
half, T>::type;
static constexpr int kElementsPerAccess = 16 / sizeof(T);
if (input_fwd == nullptr) {
// moe_permute_fwd
int threads = 64;
int blocks = (num_rows * topK + threads - 1) / threads;
moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
num_out_tokens);
blocks = num_rows;
threads = std::min(num_cols / kElementsPerAccess, 1024);
moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>(
input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
} else {
// moe_unpermute_bwd
int threads = 32;
int blocks = num_rows;
if (prob == nullptr) {
// moe_unpermute_bwd without probs
moe_permute_kernel<T, TCompute, 1, false><<<blocks, threads, 0, stream>>>(
input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
} else {
// moe_unpermute_bwd with probs
size_t smem_bytes = topK * sizeof(TCompute);
if (topK <= 8) {
moe_permute_kernel<T, TCompute, 8, true><<<blocks, threads, smem_bytes, stream>>>(
input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
} else if (topK <= 16) {
moe_permute_kernel<T, TCompute, 16, true><<<blocks, threads, smem_bytes, stream>>>(
input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
} else if (topK <= 32) {
moe_permute_kernel<T, TCompute, 32, true><<<blocks, threads, smem_bytes, stream>>>(
input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
} else if (topK <= 64) {
moe_permute_kernel<T, TCompute, 64, true><<<blocks, threads, smem_bytes, stream>>>(
input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
} else if (topK <= 128) {
moe_permute_kernel<T, TCompute, 128, true><<<blocks, threads, smem_bytes, stream>>>(
input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols);
} else {
NVTE_ERROR("topK cannot exceed 128.");
}
}
}
}
template <typename T>
void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob,
const int num_rows, const int topK, const int num_cols,
cudaStream_t stream) {
using TCompute = typename std::conditional<(std::is_same<T, __nv_fp8_e5m2>::value ||
std::is_same<T, __nv_fp8_e4m3>::value),
half, T>::type;
static constexpr int kElementsPerAccess = 16 / sizeof(T);
int blocks = num_rows;
int threads = std::min(num_cols / kElementsPerAccess, 1024);
size_t smem_bytes = topK * sizeof(TCompute);
if (prob == nullptr) {
// moe_permute_bwd
// moe_unpermute_fwd without probs
moe_unpermute_kernel<T, TCompute, false><<<blocks, threads, smem_bytes, stream>>>(
input, output, row_id_map, nullptr, num_rows, topK, num_cols);
} else {
// moe_unpermute_fwd with probs
moe_unpermute_kernel<T, TCompute, true><<<blocks, threads, smem_bytes, stream>>>(
input, output, row_id_map, prob, num_rows, topK, num_cols);
}
}
void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream) {
NVTE_API_CALL(nvte_permute);
const transformer_engine::Tensor *input_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input);
const transformer_engine::Tensor *output_cu =
reinterpret_cast<const transformer_engine::Tensor *>(output);
const transformer_engine::Tensor *sorted_row_id_cu =
reinterpret_cast<const transformer_engine::Tensor *>(sorted_row_id);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
const transformer_engine::Tensor *prob_grad_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob_grad);
const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_cu->data.dtype, T,
nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr),
reinterpret_cast<const int *>(sorted_row_id_cu->data.dptr),
reinterpret_cast<int *>(row_id_map_cu->data.dptr),
reinterpret_cast<const float *>(prob_cu->data.dptr),
reinterpret_cast<float *>(prob_grad_cu->data.dptr),
reinterpret_cast<const T *>(input_fwd_cu->data.dptr), num_rows, topK,
num_cols, num_out_tokens, stream););
}
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream) {
NVTE_API_CALL(nvte_unpermute);
const transformer_engine::Tensor *input_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input);
const transformer_engine::Tensor *output_cu =
reinterpret_cast<const transformer_engine::Tensor *>(output);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_cu->data.dtype, T,
nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr),
reinterpret_cast<int *>(row_id_map_cu->data.dptr),
reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
num_cols, stream););
}
@@ -44,6 +44,7 @@ from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.graph import make_graphed_callables
......
@@ -28,6 +28,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
......
@@ -10,6 +10,26 @@
#include "common.h"
#include "common/common.h"
/***************************************************************************************************
* Permutation
**************************************************************************************************/
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);
at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK);
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK);
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob);
/***************************************************************************************************
* Attention
**************************************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cub/cub.cuh>
#include "extensions.h"
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
const int num_tokens = input.size(0);
int num_cols = input.size(1);
const int topK = indices.size(1);
// Initialize the workspace on the first run
if (workspace.empty()) {
auto options =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false);
at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options);
at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options);
at::Tensor sorted_row_id =
torch::empty(max_expanded_token_num,
torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
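// Query the required temporary-storage size: calling cub::DeviceRadixSort::SortPairs with a
// null d_temp_storage pointer only writes the needed byte count into temp_storage_bytes.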
size_t temp_storage_bytes = 0;
int *temp_ptr = nullptr;
cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr,
temp_ptr, max_expanded_token_num);
at::Tensor temp_storage = torch::empty(
temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
workspace.push_back(sorted_indices);
workspace.push_back(row_id);
workspace.push_back(sorted_row_id);
workspace.push_back(temp_storage);
}
int *indices_ptr = reinterpret_cast<int *>(getDataPtr(indices, 0));
int *sorted_indices_ptr = reinterpret_cast<int *>(getDataPtr(workspace[0], 0));
int *row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[1], 0));
int *sorted_row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[2], 0));
void *d_temp_storage = getDataPtr(workspace[3], 0);
size_t temp_storage_bytes = std::numeric_limits<size_t>::max();
cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr,
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
at::Tensor permuted_output = torch::empty(
{num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_cu = makeTransformerEngineTensor(
input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype);
auto permuted_output_cu = makeTransformerEngineTensor(
permuted_output.data_ptr(),
{static_cast<size_t>(permuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
auto sorted_row_id_cu =
makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast<size_t>(num_tokens * topK)},
transformer_engine::DType::kInt32);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
row_id_map_cu.data(), transformer_engine::TensorWrapper().data(),
transformer_engine::TensorWrapper().data(),
transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols,
num_out_tokens, stream);
return std::make_tuple(permuted_output, row_id_map, workspace);
}
at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK);
}
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
int num_cols = input.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
at::Tensor unpermuted_output = torch::empty(
{num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_cu = makeTransformerEngineTensor(
input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype);
auto unpermuted_output_cu = makeTransformerEngineTensor(
unpermuted_output.data_ptr(),
{static_cast<size_t>(unpermuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
auto prob_cu = makeTransformerEngineTensor(prob);
nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(),
num_tokens, topK, num_cols, stream);
return unpermuted_output;
}
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob) {
const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input_bwd.scalar_type();
// Output buffer alloc
at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols},
torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor prob_grad = torch::empty(
{num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_bwd_cu = makeTransformerEngineTensor(
input_bwd.data_ptr(), {static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto act_grad_cu = makeTransformerEngineTensor(
act_grad.data_ptr(), {static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto input_fwd_cu = makeTransformerEngineTensor(
input_fwd.data_ptr(), {static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
auto prob_cu = makeTransformerEngineTensor(prob);
auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);
nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(),
row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(),
num_tokens, topK, num_cols, 0, stream);
return std::make_tuple(act_grad, prob_grad);
}
@@ -10,6 +10,12 @@
#include "../extensions.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Permutation functions
m.def("moe_permute_fwd", moe_permute_fwd);
m.def("moe_permute_bwd", moe_permute_bwd);
m.def("moe_unpermute_fwd", moe_unpermute_fwd);
m.def("moe_unpermute_bwd", moe_unpermute_bwd);
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
py::call_guard<py::gil_scoped_release>());
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""
import warnings
from typing import Tuple
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
__all__ = [
"moe_permute",
"moe_unpermute",
]
class _moe_permute(torch.autograd.Function):
"""functional Permute"""
workspace = None
max_expanded_token_num = 0
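# The CUB sort workspace (sorted indices, row ids, and radix-sort temp storage) is cached
# on the class and reused across calls; it is rebuilt only when a call needs a larger
# expanded token count (num_tokens * topK) than any previous call (see forward below).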
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
dtype: tex.DType,
indices: torch.Tensor,
num_out_tokens: int,
max_token_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Empty input check
if not inp.numel():
return inp, None
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert indices.is_cuda, "TransformerEngine needs CUDA."
# Shape check
assert inp.size(0) == indices.size(0), "Permute not possible"
# Data type check
fp8 = False
if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
fp8 = True
if fp8:
assert isinstance(
inp, Float8Tensor
), "Input must be in Float8Tensor type for FP8 moe_permute."
fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
inp = inp._data
if indices.dtype != torch.int32:
warnings.warn(
f"The data type of the input `indices` of Permute is {indices.dtype}! "
"The recommended type is torch.int32."
)
indices = indices.to(torch.int32)
topK = indices.size(1)
input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
if _moe_permute.max_expanded_token_num < input_max_expanded_token_num:
_moe_permute.max_expanded_token_num = input_max_expanded_token_num
_moe_permute.workspace = []
permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd(
inp,
dtype,
indices,
num_out_tokens,
_moe_permute.workspace,
_moe_permute.max_expanded_token_num,
)
if fp8:
permuted_act = Float8Tensor(
data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
)
ctx.row_id_map = row_id_map
ctx.num_tokens = indices.size(0)
ctx.topK = indices.size(1)
ctx.dtype = dtype
ctx.fp8 = fp8
return permuted_act, row_id_map
@staticmethod
def backward(
ctx,
permuted_act_grad: torch.Tensor,
_,
) -> Tuple[torch.Tensor, ...]:
# Empty input check
if not permuted_act_grad.numel():
return permuted_act_grad, None, None, None, None
if not permuted_act_grad.is_contiguous():
permuted_act_grad = permuted_act_grad.contiguous()
fp8 = ctx.fp8
if fp8:
assert isinstance(
permuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
fp8_dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
permuted_act_grad = permuted_act_grad._data
row_id_map = ctx.row_id_map
num_tokens = ctx.num_tokens
topK = ctx.topK
act_grad = None
if ctx.needs_input_grad[0]:
act_grad = tex.moe_permute_bwd(
permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK
)
if fp8:
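# moe_permute_bwd reuses the unpermute kernel, which averages the topK FP8 copies;
# scaling scale_inv by topK makes the dequantized gradient equal to their sum.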
act_grad = Float8Tensor(
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK
)
return act_grad, None, None, None, None
class _moe_unpermute(torch.autograd.Function):
"""functional Unpermute"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
dtype: tex.DType,
row_id_map: torch.Tensor,
probs: torch.Tensor,
) -> torch.Tensor:
# Empty input check
if not inp.numel():
ctx.probs = probs
return inp
# None probs check
if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA."
if probs.dtype != torch.float32:
warnings.warn(
f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
"The recommended type is torch.float32."
)
probs = probs.to(torch.float32)
num_tokens = probs.size(0)
topK = probs.size(1)
else:
num_tokens = row_id_map.size(0)
topK = 1
probs = torch.empty(0)
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
# Data type check
fp8 = False
if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
fp8 = True
if fp8:
assert isinstance(
inp, Float8Tensor
), "Input must be in Float8Tensor type for FP8 moe_unpermute."
fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
inp = inp._data
if row_id_map.dtype != torch.int32:
warnings.warn(
f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
"The recommended type is torch.int32."
)
row_id_map = row_id_map.to(torch.int32)
unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
)
ctx.dtype = dtype
ctx.save_for_backward(inp, row_id_map, probs)
ctx.fp8 = fp8
return unpermuted_output
@staticmethod
def backward(
ctx,
unpermuted_act_grad: torch.Tensor,
) -> Tuple[torch.Tensor, None, None, torch.Tensor]:
# Empty input check
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, None, ctx.probs
if not unpermuted_act_grad.is_contiguous():
unpermuted_act_grad = unpermuted_act_grad.contiguous()
fp8 = ctx.fp8
if fp8:
assert isinstance(
unpermuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
fp8_dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
unpermuted_act_grad = unpermuted_act_grad._data
inp, row_id_map, probs = ctx.saved_tensors
act_grad = None
if ctx.needs_input_grad[0]:
act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs
)
if fp8:
act_grad = Float8Tensor(
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
)
if not ctx.needs_input_grad[3]:
prob_grad = None
return act_grad, None, None, prob_grad
def moe_permute(
inp: torch.Tensor,
dtype: tex.DType,
indices: torch.Tensor,
num_out_tokens: int = -1,
max_token_num: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Permute the tokens based on the indices. Tokens with the same index are grouped together.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
dtype: tex.DType
Data type of the input tensor.
indices: torch.Tensor
The token-to-expert indices tensor, of shape [num_tokens, topK] and dtype torch.int32.
num_out_tokens: int, default = -1
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
max_token_num: int, default = -1
The maximum number of tokens, used for workspace allocation.
By default, set to '-1', meaning the calculation of the size of workspace is
automatically taken over by the operator.
"""
return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num)
def moe_unpermute(
inp: torch.Tensor,
dtype: tex.DType,
row_id_map: torch.Tensor,
probs: torch.Tensor = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
corresponding probabilities.
Parameters
----------
inp: torch.Tensor
Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
dtype: tex.DType
Data type of the input tensor.
row_id_map: torch.Tensor
The mapping table of sorted indices used to unpermute the tokens;
this is the second output tensor of `moe_permute`.
probs: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
By default, set to None, in which case the permuted tokens are merged by direct accumulation.
"""
return _moe_unpermute.apply(inp, dtype, row_id_map, probs)
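# Minimal usage sketch (illustrative only; the tensor sizes and routing indices below are
# made up for the example): round-trip a small batch of tokens through moe_permute and
# moe_unpermute on a CUDA device, then backpropagate through both ops.
if __name__ == "__main__":
    example_tokens = torch.randn(4, 16, dtype=torch.float16, device="cuda", requires_grad=True)
    example_indices = torch.tensor(
        [[0, 1], [1, 2], [2, 0], [1, 0]], dtype=torch.int32, device="cuda"
    )
    example_probs = torch.softmax(torch.randn(4, 2, device="cuda"), dim=-1)
    permuted, row_id_map = moe_permute(example_tokens, tex.DType.kFloat16, example_indices)
    # ... expert MLPs would process `permuted` here ...
    merged = moe_unpermute(permuted, tex.DType.kFloat16, row_id_map, example_probs)
    merged.sum().backward()
    print(merged.shape, example_tokens.grad.shape)  # torch.Size([4, 16]) for both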