Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
...@@ -81,6 +81,26 @@ int sm_count(int device_id) {
return cache[device_id];
}
void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
static std::vector<std::pair<int, int>> cache(num_devices());
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
int ori_dev = current_device();
if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(device_id));
int min_pri, max_pri;
NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri));
if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev));
cache[device_id] = std::make_pair(min_pri, max_pri);
};
std::call_once(flags[device_id], init);
*low_priority = cache[device_id].first;
*high_priority = cache[device_id].second;
}
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
// NOTE: This needs to be guarded at compile time because the
...
...@@ -38,6 +38,16 @@ int sm_arch(int device_id = -1);
*/
int sm_count(int device_id = -1);
/* \brief Minimum and maximum stream priorities supported on device
*
* \param[in] device_id CUDA device (default is current device)
*
* \param[out] low_priority Lowest priority value on device.
*
* \param[out] high_priority Highest priority value on device.
*/
void stream_priority_range(int *low_priority, int *high_priority, int device_id = -1);
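A minimal caller-side sketch of the new helper (illustration only, not part of this commit; it assumes the NVTE_CHECK_CUDA macro used elsewhere in the codebase):
// Query the priority range of the current device and create a stream at the
// highest supported priority (the numerically lowest CUDA priority value).
int low_priority = 0, high_priority = 0;
transformer_engine::cuda::stream_priority_range(&low_priority, &high_priority);
cudaStream_t priority_stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&priority_stream, cudaStreamNonBlocking, high_priority));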
/* \brief CUDA Multicast support status for device
*
* \param[in] device_id CUDA device (default is current device)
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dequantize_kernels.cuh
* \brief CUDA kernels to cast from MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <cfloat>
#include <limits>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
namespace dequantization {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t ELEMS_PER_THREAD = 16;
constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128
constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16
constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(ITERATIONS >= 1);
template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const e8m0_t *const scales_ptr, const size_t rows, const size_t cols,
const size_t scales_stride) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y;
const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X;
const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y;
const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X;
const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE;
const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE;
// const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE;
const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE;
const int thread_offset_Y = tid_rowwise_Y;
const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD;
// const int thread_offset_X_colwise = tid_colwise_X;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
if (is_master_thread) {
// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate.
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
ptx::mbarrier_init(&mbar[iter], THREADS_PER_CHUNK);
}
ptx::fence_proxy_async_shared_cta();
}
// Syncthreads so initialized barrier is visible to all threads.
__syncthreads();
int parity = 0;
constexpr int iteration_zero = 0;
constexpr int buffer_zero = 0;
if (is_master_thread) {
const int chunk_stage_offset_Y = chunk_offset_Y;
const int chunk_stage_offset_X = chunk_offset_X;
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_sh[buffer_zero]),
reinterpret_cast<const uint64_t *>(&tensor_map_input), chunk_stage_offset_X,
chunk_stage_offset_Y, &mbar[iteration_zero]);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(&mbar[iteration_zero], transaction_size);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(&mbar[iteration_zero]);
}
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const int buff = iter % BUFFERS_NUM;
const int next_iter = iter + 1;
if (next_iter < ITERATIONS) {
if (is_master_thread) {
const int next_buff = next_iter % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_sh[next_buff]),
reinterpret_cast<const uint64_t *>(&tensor_map_input), chunk_it_offset_x,
chunk_it_offset_y, &mbar[next_iter]);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(&mbar[next_iter], transaction_size);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(&mbar[next_iter]);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
const int scale_offset_Y =
USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y)
: (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y);
const int scale_offset_X =
USE_ROWWISE_SCALING
? (scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE)
: (scales_colwise_chunk_offset_X + tid_colwise_X);
const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X;
const e8m0_t biased_exponent = scales_ptr[scale_idx];
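// (Illustrative, added comment: a stored biased exponent of 130 decodes to block_scale = 2^(130 - 127) = 8.)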
const float block_scale = exp2f(static_cast<float>(biased_exponent) - FP32_EXPONENT_BIAS);
if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
Vec<OType, ELEMS_PER_THREAD> out;
const int shmem_offset_y = thread_offset_Y;
const int shmem_offset_x = thread_offset_X_rowwise;
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]);
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
out.data.elt[j] = static_cast<OType>(block_scale * static_cast<float>(in.data.elt[j]));
}
out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]);
} else {
#pragma unroll
for (int i = 0; i < BUFFER_DIM_Y; ++i) {
const float elt = static_cast<float>(in_sh[buff][i][tid_colwise_X]);
out_sh[buff][i][tid_colwise_X] = static_cast<OType>(block_scale * elt);
}
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
parity ^= 1;
// Destroy barrier. This invalidates the memory region of the barrier. If
// further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
ptx::mbarrier_invalid(&mbar[iter]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr), nullptr,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream);
const auto &input_shape = input.data.shape;
NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions.");
if (use_rowwise_scaling) {
NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data.");
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
}
if (use_colwise_scaling) {
NVTE_CHECK(input.has_columnwise_data(), "Cannot dequantize tensor without columnwise data.");
NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type.");
}
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
// TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1;
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X);
const size_t unpadded_scales_Y_rowwise = rows;
const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise);
const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise);
const size_t unpadded_scales_X_colwise = cols;
const size_t scales_Y_rowwise =
DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) *
scale_tensor_alignment_Y_rowwise;
const size_t scales_X_rowwise =
DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) *
scale_tensor_alignment_X_rowwise;
const size_t scales_Y_colwise =
DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) *
scale_tensor_alignment_Y_colwise;
const size_t scales_X_colwise =
DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) *
scale_tensor_alignment_X_colwise;
const e8m0_t *const scales_ptr =
use_rowwise_scaling ? reinterpret_cast<e8m0_t *>(input.scale_inv.dptr)
: reinterpret_cast<e8m0_t *>(input.columnwise_scale_inv.dptr);
const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise;
const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data;
const dim3 block(THREADS_PER_CHUNK);
const dim3 grid(chunks_X, chunks_Y);
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_X_rowwise, SCALE_DIM_X,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType));
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(OType));
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr,
rows, cols, scales_stride);); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace dequantization
namespace detail {
void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
if (is_tensor_scaling(input.scaling_mode)) {
dequantization::fp8_dequantize(input, output, stream);
} else if (is_mxfp_scaling(input.scaling_mode)) {
if (is_supported_by_CC_100()) {
dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
} // namespace detail
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file ptx.cuh
* \brief BW PTX
*/
#ifndef TRANSFORMER_ENGINE_PTX_CUH_
#define TRANSFORMER_ENGINE_PTX_CUH_
#include <cuda.h>
#include <cuda_runtime.h>
namespace transformer_engine {
namespace ptx {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count)
: "memory");
}
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
asm volatile("fence.mbarrier_init.release.cluster;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster
__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) {
uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
// triggers async copy, i.e. the thread continues until wait() on mbarrier
// barrier condition:
// - leader must arrive (i.e. 1 thread as set above)
// - TMA hardware subtracts bytes from expect_tx counter, must reach zero
asm volatile(
"cp.async.bulk.shared::cta.global"
".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr),
"l"(src_global_ptr), "r"(size), "r"(mbar_ptr)
: "memory");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x,
const uint32_t offset_y, uint64_t *mbar) {
uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
// triggers async copy, i.e. the thread continues until wait() on mbarrier
// barrier condition:
// - leader must arrive (i.e. 1 thread as set above)
// - TMA hardware subtracts bytes from expect_tx counter, must reach zero
asm volatile(
"cp.async.bulk.tensor.2d.shared::cluster.global.tile"
".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr),
"l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr)
: "memory");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
const uint64_t *src_shmem,
const uint32_t size) {
uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr),
"r"(src_shmem_ptr), "r"(size)
: "memory");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y,
uint64_t *src_shmem) {
uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"(
tensor_map_ptr),
"r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr)
: "memory");
}
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() {
asm volatile("cp.async.bulk.wait_group 0;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
template <size_t W>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
asm volatile("cp.async.bulk.wait_group.read 0;");
}
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() {
asm volatile("cp.async.bulk.wait_group.read 0;");
}
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() {
asm volatile("cp.async.bulk.wait_group.read 1;");
}
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() {
asm volatile("cp.async.bulk.wait_group.read 2;");
}
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
asm volatile("cp.async.bulk.wait_group.read 4;");
}
// Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); }
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;");
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} // namespace ptx
namespace {
template <int num_barriers, int THREADS_PER_BLOCK>
__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if (is_master_thread) {
// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate.
#pragma unroll
for (int iter = 0; iter < num_barriers; ++iter) {
ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK);
}
ptx::fence_proxy_async_shared_cta();
}
// Syncthreads so initialized barrier is visible to all threads.
__syncthreads();
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <int num_barriers>
__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Destroy barrier. This invalidates the memory region of the barrier. If
// further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int iter = 0; iter < num_barriers; ++iter) {
ptx::mbarrier_invalid(&mbar[iter]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
const size_t num_bytes, uint64_t *barrier,
const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if (is_master_thread) {
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
reinterpret_cast<const uint64_t *>(src),
num_bytes, barrier);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(barrier);
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X,
const size_t chunk_Y, const size_t num_bytes,
uint64_t *barrier, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if (is_master_thread) {
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
reinterpret_cast<const uint64_t *>(src), chunk_X,
chunk_Y, barrier);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(barrier);
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src,
const size_t chunk_X1, const size_t chunk_Y1,
void *dst2, const void *src2,
const size_t chunk_X2, const size_t chunk_Y2,
const size_t num_bytes, uint64_t *barrier,
const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if (is_master_thread) {
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
reinterpret_cast<const uint64_t *>(src), chunk_X1,
chunk_Y1, barrier);
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst2),
reinterpret_cast<const uint64_t *>(src2),
chunk_X2, chunk_Y2, barrier);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(barrier);
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__forceinline__ __device__ void copy_2d_to_sharedx3(
void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, void *dst2,
const void *src2, const size_t chunk_X2, const size_t chunk_Y2, void *dst3, const void *src3,
const size_t chunk_X3, const size_t chunk_Y3, const size_t num_bytes, uint64_t *barrier,
const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if (is_master_thread) {
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
reinterpret_cast<const uint64_t *>(src), chunk_X1,
chunk_Y1, barrier);
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst2),
reinterpret_cast<const uint64_t *>(src2),
chunk_X2, chunk_Y2, barrier);
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst3),
reinterpret_cast<const uint64_t *>(src3),
chunk_X3, chunk_Y3, barrier);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(barrier);
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_PTX_CUH_
...@@ -15,7 +15,7 @@
#include "cuda_runtime.h"
#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
- pybind11::enum_<transformer_engine::DType>(m, "DType") \
+ pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
.value("kByte", transformer_engine::DType::kByte) \
.value("kInt32", transformer_engine::DType::kInt32) \
.value("kFloat32", transformer_engine::DType::kFloat32) \
...@@ -23,12 +23,12 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
- pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type") \
+ pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \
.value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \
- pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type") \
+ pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local()) \
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \
...@@ -36,7 +36,7 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
- pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout") \
+ pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) \
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \
...@@ -52,15 +52,17 @@
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \
- pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend") \
+ pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
- pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType") \
+ pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
+ pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \ .value("RS", transformer_engine::CommOverlapType::RS) \
.value("AG", transformer_engine::CommOverlapType::AG); \ .value("AG", transformer_engine::CommOverlapType::AG); \
pybind11::enum_<transformer_engine::CommOverlapAlgo>(m, "CommOverlapAlgo") \ pybind11::enum_<transformer_engine::CommOverlapAlgo>(m, "CommOverlapAlgo", \
pybind11::module_local()) \
.value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \
.value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \
.value("SPLIT_PIPELINED_AG_P2P", \ .value("SPLIT_PIPELINED_AG_P2P", \
...@@ -71,8 +73,38 @@ ...@@ -71,8 +73,38 @@
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
py::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local()) \
.def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \
py::call_guard<py::gil_scoped_release>()) \
.def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \
py::call_guard<py::gil_scoped_release>()) \
.def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \
py::call_guard<py::gil_scoped_release>()) \
.def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \
py::call_guard<py::gil_scoped_release>()); \
py::class_<transformer_engine::CommOverlapBase, \
std::shared_ptr<transformer_engine::CommOverlapBase>, \
transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \
.def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \
py::call_guard<py::gil_scoped_release>()); \
py::class_<transformer_engine::CommOverlapP2PBase, \
std::shared_ptr<transformer_engine::CommOverlapP2PBase>, \
transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \
pybind11::module_local()) \
.def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \
py::call_guard<py::gil_scoped_release>()); \
m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \ py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
m.def( \
"get_stream_priority_range", \
[](int device_id = -1) { \
int low_pri, high_pri; \
transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \
return std::make_pair(low_pri, high_pri); \
}, \
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
......
...@@ -9,8 +9,6 @@
#include <string>
- #include "../common.h"
namespace transformer_engine {
/*! \brief Get environment variable and convert to type
...
...@@ -44,6 +44,13 @@ class VectorizedStorage {
return *this;
}
inline __device__ ~VectorizedStorage() {}
/* \brief Access to separate elements. */
inline __device__ DType *separate() { return scratch_.separate; }
inline __device__ const DType *separate() const { return scratch_.separate; }
inline __device__ LType &aligned() { return scratch_.aligned; }
};
// Returns const LType if DType is const
...@@ -167,9 +174,11 @@ constexpr int unary_kernel_threads = 512;
template <int nvec, bool aligned, typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
- void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
- ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N,
- const size_t num_aligned_elements) {
+ void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output,
+ const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p,
+ const size_t N, const size_t num_aligned_elements) {
if (noop != nullptr && noop[0] == 1.0f) return;
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
...@@ -322,9 +331,9 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs)
template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typename InputType,
typename OutputType>
- void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
- fp32 *amax, fp32 *scale_inv, const size_t N, const Param params,
- cudaStream_t stream) {
+ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
+ const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
+ const Param params, cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output);
...@@ -337,16 +346,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c
switch (align) {
case Alignment::SAME_ALIGNED:
unary_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
- input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
+ input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
unary_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
- input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
+ input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
- input, output, scale, amax, scale_inv, params, N, N);
+ input, noop, output, scale, amax, scale_inv, params, N, N);
break;
}
}
...@@ -395,12 +404,6 @@ __launch_bounds__(unary_kernel_threads) __global__
ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n,
const Param p, const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
- for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
- const size_t id_x = tid % num_aligned_elements;
- const size_t id_y = tid / num_aligned_elements;
- VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n);
- VectorizedLoader<InputType, nvec, aligned> loader1(input + id_y * n * 2 + n, n);
- VectorizedStorer<OutputType, nvec, aligned> storer(output + id_y * n, n);
ComputeType max = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
...@@ -408,6 +411,13 @@ __launch_bounds__(unary_kernel_threads) __global__
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements;
VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n);
VectorizedLoader<InputType, nvec, aligned> loader1(input + id_y * n * 2 + n, n);
VectorizedStorer<OutputType, nvec, aligned> storer(output + id_y * n, n);
loader0.load(id_x, n);
loader1.load(id_x, n);
#pragma unroll
...@@ -423,7 +433,7 @@ __launch_bounds__(unary_kernel_threads) __global__
storer.separate()[i] = static_cast<OutputType>(static_cast<ComputeType>(temp));
}
storer.store(id_x, n);
}
if constexpr (is_fp8<OutputType>::value) {
// Reduce amax over block
if (amax != nullptr) {
...@@ -439,7 +449,6 @@ __launch_bounds__(unary_kernel_threads) __global__
reciprocal<ComputeType>(scale_inv, s);
}
}
- }
}
template <int nvec, typename ComputeType, typename Param,
...@@ -482,9 +491,17 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output,
const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv,
const size_t m, const size_t n, const Param p,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
ComputeType max = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements;
...@@ -507,12 +524,35 @@ __launch_bounds__(unary_kernel_threads) __global__
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
ComputeType after_dgate = grad_val * Activation(gelu_in, p);
if constexpr (is_fp8<OutputType>::value) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(after_dgelu), max);
after_dgelu = after_dgelu * s;
max = fmaxf(fabsf(after_dgate), max);
after_dgate = after_dgate * s;
}
storer0.separate()[i] = static_cast<OutputType>(after_dgelu);
storer1.separate()[i] = static_cast<OutputType>(after_dgate);
}
storer0.store(id_x, n);
storer1.store(id_x, n);
}
if constexpr (is_fp8<OutputType>::value) {
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
}
template <int nvec, typename ComputeType, typename Param,
...@@ -520,8 +560,9 @@ template <int nvec, typename ComputeType, typename Param,
ComputeType (*Dactivation)(const ComputeType, const Param &), typename InputType,
typename OutputType>
void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input,
- OutputType *output, const size_t m, const size_t n,
- const Param &p, cudaStream_t stream) {
+ OutputType *output, const fp32 *scale, fp32 *amax,
+ fp32 *scale_inv, const size_t m, const size_t n, const Param &p,
+ cudaStream_t stream) {
if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
...@@ -532,18 +573,19 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu
switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) {
case Alignment::SAME_ALIGNED:
dgated_act_kernel<nvec, true, ComputeType, Param, Activation, Dactivation>
- <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p,
- num_aligned_elements);
+ <<<num_blocks, threads, 0, stream>>>(grad, input, output, scale, amax, scale_inv, m, n,
+ p, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
dgated_act_kernel<nvec, false, ComputeType, Param, Activation, Dactivation>
- <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p,
- num_aligned_elements);
+ <<<num_blocks, threads, 0, stream>>>(grad, input, output, scale, amax, scale_inv, m, n,
+ p, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation>
- <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, p, n);
+ <<<num_blocks, threads, 0, stream>>>(grad, input, output, scale, amax, scale_inv, m, n,
+ p, n);
break;
}
}
...
...@@ -819,6 +819,21 @@ __device__ __forceinline__ float warp_reduce_max(const float m) {
return tmp;
}
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
float val_tmp = val;
#pragma unroll
for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset);
__builtin_assume(val_tmp >= 0);
__builtin_assume(val_other >= 0);
val_tmp = fmaxf(val_tmp, val_other);
}
// Broadcast the amax to the other threads of the warp from lane 0
constexpr int subwarp_lane_zero = 0;
val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero);
return val_tmp;
}
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
__shared__ float staging[num_warps];
...@@ -837,6 +852,29 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war
return result;
}
/**
* Max reduction in subwarps
* E.g., if nvec=4, each warp processes 128 elements (32 x 4), which covers four MXFP8 scaling factors.
* To compute an actual scaling factor for 32 consecutive elements, only 8 threads need to participate,
* so the warp is split into four subwarps, each 8 threads wide.
* 'Butterfly' reduction is used inside subwarps.
*/
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
float val_tmp = val;
#pragma unroll
for (int offset = subwarp_width / 2; offset > 0; offset /= 2) {
const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width);
__builtin_assume(val_tmp >= 0);
__builtin_assume(val_other >= 0);
val_tmp = fmaxf(val_tmp, val_other);
}
// Broadcast the amax to the other threads of the subwarp from subwarp lane 0
constexpr int subwarp_lane_zero = 0;
val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width);
return val_tmp;
}
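A short usage sketch for the subwarp reduction (illustrative only; the per-thread elts registers, nvec = 4, and the fp8e4m3 target type are assumptions, not part of this commit):
// Per-thread partial amax over the nvec = 4 elements held in registers.
float thread_amax = 0.0f;
#pragma unroll
for (int i = 0; i < 4; ++i) {
  thread_amax = fmaxf(thread_amax, fabsf(static_cast<float>(elts[i])));
}
// 8 threads (32 elements / 4 per thread) cooperate on one MXFP8 scaling factor,
// then every thread of the subwarp receives the reduced value.
const float block_amax = subwarp_reduce_max_broadcast<8>(thread_amax);
const e8m0_t biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<fp8e4m3>::max_norm_rcp);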
// Works only on positive values
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
...@@ -857,6 +895,79 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
*value_inv = __frcp_rn(value);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using e8m0_t = uint8_t;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };
template <typename T>
struct Numeric_Traits;
template <>
struct Numeric_Traits<fp8e4m3> {
static constexpr int maxUnbiasedExponent = 8;
static constexpr double maxNorm = 448;
};
template <>
struct Numeric_Traits<fp8e5m2> {
static constexpr int maxUnbiasedExponent = 15;
static constexpr double maxNorm = 57344;
};
template <typename T>
struct Quantized_Limits {
static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
static constexpr float max_norm_rcp = 1.0 / max_norm;
static constexpr float emax = 1 << max_unbiased_exponent;
static constexpr float emax_rcp = 1.0 / emax;
};
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
if (isnan(val)) {
return 0xFF;
}
if (isinf(val)) {
return 0xFE;
}
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out;
asm volatile(
"{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}"
: "=h"(out)
: "f"(val));
return *reinterpret_cast<e8m0_t *>(&out);
#else
if (val == 0.0f) {
return 0x00;
}
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
#endif
}
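// Illustrative example (added comment, not part of the original source): val = 96.0f = 1.5 * 2^6
// has biased exponent 133 and a non-zero mantissa, so the conversion rounds up and returns 134,
// i.e. the represented scale is 2^(134 - 127) = 128.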
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
...@@ -6,6 +6,7 @@
#include "transformer_engine/activation.h"
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
...@@ -332,17 +333,26 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_ ...@@ -332,17 +333,26 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size}; auto dbias_shape = std::vector<size_t>{hidden_size};
- auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
- auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
- auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
- auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
- auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
+ // Evil hack to specify TE impl
+ // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
+ // on what pointers are allocated, e.g. whether to output with
+ // column-wise data. However, we don't have access to any allocated
+ // buffers in this function. We pass a dummy pointer as a
+ // workaround.
int temp = 0;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto dact_input_tensor =
TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
auto output_tensor = TensorWrapper();
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
// For now, all dbias_dact(-s) have the same workspace size
- nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
- output_tensor.data(), output_trans_tensor.data(),
+ nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
...@@ -384,36 +394,31 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -384,36 +394,31 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
   auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
   switch (act_enum) {
     case NVTE_Activation_Type::GELU:
-      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace.data(), stream);
+      nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                dbias_tensor.data(), workspace.data(), stream);
       break;
     case NVTE_Activation_Type::SILU:
-      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace.data(), stream);
+      nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                dbias_tensor.data(), workspace.data(), stream);
       break;
     case NVTE_Activation_Type::RELU:
-      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace.data(), stream);
+      nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                dbias_tensor.data(), workspace.data(), stream);
       break;
     case NVTE_Activation_Type::QGELU:
-      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace.data(), stream);
+      nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 dbias_tensor.data(), workspace.data(), stream);
       break;
     case NVTE_Activation_Type::SRELU:
-      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace.data(), stream);
+      nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 dbias_tensor.data(), workspace.data(), stream);
break; break;
default: default:
...@@ -468,36 +473,31 @@ Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu ...@@ -468,36 +473,31 @@ Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
   auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
   switch (act_type) {
     case NVTE_Activation_Type::GELU:
-      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace_tensor.data(), stream);
+      nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                dbias_tensor.data(), workspace_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::SILU:
-      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace_tensor.data(), stream);
+      nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                dbias_tensor.data(), workspace_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::RELU:
-      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
-                                      output_tensor.data(), output_trans_tensor.data(),
-                                      dbias_tensor.data(), workspace_tensor.data(), stream);
+      nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                dbias_tensor.data(), workspace_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::QGELU:
-      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace_tensor.data(), stream);
+      nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 dbias_tensor.data(), workspace_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::SRELU:
-      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
-                                       output_tensor.data(), output_trans_tensor.data(),
-                                       dbias_tensor.data(), workspace_tensor.data(), stream);
+      nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 dbias_tensor.data(), workspace_tensor.data(), stream);
break; break;
default: default:
...@@ -555,29 +555,29 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -555,29 +555,29 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
   switch (act_enum) {
     case NVTE_Activation_Type::GEGLU:
-      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
-                                 output_trans_tensor.data(), stream);
+      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 stream);
       break;
     case NVTE_Activation_Type::SWIGLU:
-      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
+      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::REGLU:
-      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
-                                 output_trans_tensor.data(), stream);
+      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 stream);
       break;
     case NVTE_Activation_Type::QGEGLU:
-      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
+      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::SREGLU:
-      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
+      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), stream);
break; break;
default: default:
NVTE_ERROR("Unsupported ActivationEnum"); NVTE_ERROR("Unsupported ActivationEnum");
...@@ -622,30 +622,30 @@ Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu ...@@ -622,30 +622,30 @@ Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
   auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
   switch (act_type) {
     case NVTE_Activation_Type::GEGLU:
-      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
-                                 output_trans_tensor.data(), stream);
+      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 stream);
       break;
     case NVTE_Activation_Type::SWIGLU:
-      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
+      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::REGLU:
-      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
-                                 output_trans_tensor.data(), stream);
+      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
+                                 stream);
       break;
     case NVTE_Activation_Type::QGEGLU:
-      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
+      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), stream);
       break;
     case NVTE_Activation_Type::SREGLU:
-      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
-                                  output_tensor.data(), output_trans_tensor.data(), stream);
+      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
+                                  output_tensor.data(), stream);
break; break;
default: default:
NVTE_ERROR("Unsupported ActivationEnum"); NVTE_ERROR("Unsupported ActivationEnum");
......
...@@ -25,7 +25,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -25,7 +25,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);
-  nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
+  nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
} }
Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
...@@ -48,7 +48,7 @@ Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a ...@@ -48,7 +48,7 @@ Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a
auto input_tensor = TensorWrapper(input, shape, in_dtype); auto input_tensor = TensorWrapper(input, shape, in_dtype);
auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv);
-  nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
+  nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
...@@ -76,7 +76,7 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t ...@@ -76,7 +76,7 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
-  nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
+  nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);
} }
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
...@@ -96,7 +96,7 @@ Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type ...@@ -96,7 +96,7 @@ Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type
auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv); auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, out_dtype); auto output_tensor = TensorWrapper(output, shape, out_dtype);
-  nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
+  nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
#include "extensions.h" #include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/ffi.h" #include "xla/ffi/api/ffi.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -89,13 +90,12 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -89,13 +90,12 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
auto input_trans_shape = std::vector<size_t>{n, m}; auto input_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
-  auto input_cast_tensor =
-      TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, desc.out_dtype,
-                                               amax_out, scale, scale_inv);
-  nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(),
-                      stream);
+  auto output_tensor =
+      TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(input_cast_trans, desc.out_dtype, input_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
+  nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
} }
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
...@@ -131,11 +131,11 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -131,11 +131,11 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
-  nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
-                      stream);
+  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
+  nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
...@@ -159,15 +159,22 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi ...@@ -159,15 +159,22 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size}; auto dbias_shape = std::vector<size_t>{hidden_size};
-  auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
-  auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
-  auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);
+  // Evil hack to specify TE impl
+  // Note: nvte_quantize_dbias chooses its internal impl based on what
+  // pointers are allocated, e.g. whether to output with column-wise
+  // data. However, we don't have access to any allocated buffers in
+  // this function. We pass a dummy pointer as a workaround.
+  int temp = 0;
+  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
+  auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
+  output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
+  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
   TensorWrapper dummy_workspace;
-  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
-                            dbias_tensor.data(), dummy_workspace.data(), nullptr);
+  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
+                      dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape()); auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
...@@ -203,14 +210,14 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -203,14 +210,14 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto output_tensor = auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
   auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
-  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
-                            dbias_tensor.data(), workspace.data(), stream);
+  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
+                      workspace.data(), stream);
} }
Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
...@@ -253,13 +260,13 @@ Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buf ...@@ -253,13 +260,13 @@ Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buf
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
-  auto output_trans_tensor =
-      TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv);
+  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
+  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
-  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
-                            dbias_tensor.data(), workspace_tensor.data(), stream);
+  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
+                      workspace_tensor.data(), stream);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
......
...@@ -354,11 +354,6 @@ def fp8_autocast( ...@@ -354,11 +354,6 @@ def fp8_autocast(
assert ( assert (
fp8_recipe.scaling_factor_compute_algo is None fp8_recipe.scaling_factor_compute_algo is None
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.override_linear_precision == (
False,
False,
False,
), "DelayedScaling override_linear_precision isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
if mesh_resource is None: if mesh_resource is None:
......
recursive-include build_tools *.*
recursive-include common_headers *.*
recursive-include csrc *.*
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for Paddle"""
# pylint: disable=wrong-import-position,wrong-import-order
import logging
from importlib.metadata import version
from transformer_engine.common import is_package_installed
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_paddle"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install"
" transformer-engine[paddle]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using 'pip"
" install transformer-engine[paddle]==VERSION'",
module_name,
)
from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import
_load_library()
from .fp8 import fp8_autocast
from .layer import (
Linear,
LayerNorm,
LayerNormLinear,
LayerNormMLP,
FusedScaleMaskSoftmax,
DotProductAttention,
MultiHeadAttention,
TransformerLayer,
RotaryPositionEmbedding,
)
from .recompute import recompute
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Constants"""
from enum import Enum
import paddle
from transformer_engine import transformer_engine_paddle as tex
class FP8FwdTensors(Enum):
"""Used as named indices on the `scale`, `scale_inv`,
and `amax` tensors in the `FP8TensorMeta` class."""
GEMM1_INPUT = 0
GEMM1_WEIGHT = 1
GEMM1_OUTPUT = 2
GEMM2_INPUT = 3
GEMM2_WEIGHT = 4
GEMM2_OUTPUT = 5
class FP8BwdTensors(Enum):
"""Used as named indices on the `scale`, `scale_inv`,
and `amax` tensors in the `FP8TensorMeta` class."""
GRAD_OUTPUT1 = 0
GRAD_INPUT1 = 1
GRAD_OUTPUT2 = 2
GRAD_INPUT2 = 3
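# Illustration only (editor's sketch, not part of the original file): these enums are consumed
# as positional indices into the per-layer FP8 metadata tensors (`scale`, `scale_inv`,
# `amax_history`); the tensor size below is an assumption based on that usage.
import paddle

scale_fwd = paddle.ones([len(FP8FwdTensors)], dtype="float32")  # one slot per forward tensor
idx = FP8FwdTensors.GEMM1_WEIGHT.value  # -> 1
weight_scale = scale_fwd[idx]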
"""
Map from paddle dtype to TE dtype
"""
TE_DType = {
paddle.uint8: tex.DType.kByte,
paddle.int32: tex.DType.kInt32,
paddle.float32: tex.DType.kFloat32,
paddle.float16: tex.DType.kFloat16,
paddle.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross")
LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None)
dist_group_type = paddle.distributed.collective.Group
RecomputeFunctionNames = ("unpack", "backward")
AttnBiasType = {
"no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS,
"pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
"post_scale_bias": tex.NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
}
AttnMaskType = {
"no_mask": tex.NVTE_Mask_Type.NVTE_NO_MASK,
"padding": tex.NVTE_Mask_Type.NVTE_PADDING_MASK,
"causal": tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
}
FusedAttnBackend = {
"F16_max512_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
"No_Backend": tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
import math
from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
from transformer_engine import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta, get_global_fp8_state
BACKEND_F16m512_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16
def gemm(
A: paddle.Tensor,
B: paddle.Tensor,
dtype: paddle.dtype,
workspace: paddle.Tensor,
gelu: bool = False,
gelu_input: Optional[paddle.Tensor] = None,
grad: bool = False,
accumulate: bool = False,
layout: str = "TN",
out: Optional[paddle.Tensor] = None,
out_dtype: Optional[paddle.dtype] = None,
bias: Optional[paddle.Tensor] = None,
use_bias: bool = False,
) -> Tuple[Union[paddle.Tensor, None], ...]:
"""Non FP8 GEMM."""
assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
transa = layout[0] == "T"
transb = layout[1] == "T"
if out is None:
if accumulate:
out = paddle.zeros(
shape=[
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
],
dtype=out_dtype if out_dtype is not None else dtype,
)
else:
out = paddle.empty(
shape=[
B.shape[1] if transb else B.shape[0],
A.shape[0] if transa else A.shape[1],
],
dtype=out_dtype if out_dtype is not None else dtype,
)
if gelu and not grad:
gelu_input = paddle.empty_like(out, dtype=dtype)
elif not gelu:
gelu_input = None
if grad and use_bias:
grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype)
else:
grad_bias = None
bias = bias if use_bias else None
assert (
A.dtype == dtype and B.dtype == dtype
), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}"
input_dtype = TE_DType[dtype]
output_dtype = TE_DType[out.dtype]
if use_bias:
bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
else:
bias_dtype = output_dtype
tex.te_gemm(
A,
None,
B,
None,
grad_bias if grad else bias,
out,
None, # out_scale
None, # out_amax
gelu_input,
workspace,
0, # A_index
0, # B_index
0, # D_index
int(input_dtype),
int(input_dtype),
int(output_dtype),
int(bias_dtype),
transa,
transb,
grad,
workspace.shape[0],
accumulate,
False, # use_split_accumulator
0, # math_sm_count
)
return out, grad_bias, gelu_input
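# Illustration only (editor's sketch, not part of the original file): how the `layout` string
# maps to the output shape computed in gemm() above. This checks the shape rule only; it does
# not reproduce te_gemm's numerics.
def _gemm_out_shape(a_shape, b_shape, layout):
    transa, transb = layout[0] == "T", layout[1] == "T"
    return [b_shape[1] if transb else b_shape[0], a_shape[0] if transa else a_shape[1]]

assert _gemm_out_shape([16, 32], [8, 32], "TN") == [8, 16]  # A [n, k], B [m, k] -> out [m, n]
assert _gemm_out_shape([32, 16], [8, 32], "NN") == [8, 16]  # A [k, n], B [m, k] -> out [m, n]
assert _gemm_out_shape([32, 16], [32, 8], "NT") == [8, 16]  # A [k, n], B [k, m] -> out [m, n]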
def fp8_gemm(
A: paddle.Tensor,
A_scale_inv: paddle.Tensor,
A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
A_dtype: tex.DType,
B: paddle.Tensor,
B_scale_inv: paddle.Tensor,
B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
B_dtype: tex.DType,
out_dtype: paddle.dtype,
workspace: paddle.Tensor,
gelu: bool = False,
accumulate: bool = False,
out: Optional[paddle.Tensor] = None,
out_index=None,
fp8_meta_tensor: FP8TensorMeta = None,
bias: Optional[paddle.Tensor] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
) -> paddle.Tensor:
"""TN layout GEMM with fp8 inputs."""
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
if out is None:
if accumulate:
out = paddle.zeros(
shape=[
B.shape[0],
A.shape[0],
],
dtype=out_dtype,
)
else:
out = paddle.empty(
shape=[
B.shape[0],
A.shape[0],
],
dtype=out_dtype,
)
# Use bfloat16 as default bias_dtype
bias_dtype = paddle.bfloat16 if bias is None else bias.dtype
if gelu:
gelu_input = paddle.empty_like(out, dtype=bias_dtype)
else:
gelu_input = None
bias_dtype = TE_DType[bias_dtype]
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
tex.te_gemm(
A,
A_scale_inv,
B,
B_scale_inv,
bias if use_bias else None,
out,
None if out_index is None else fp8_meta_tensor.scale,
None if out_index is None else fp8_meta_tensor.amax_history,
gelu_input, # this is pre_gelu_out
workspace,
A_fp8_tensor.value,
B_fp8_tensor.value,
0 if out_index is None else out_index,
int(A_dtype),
int(B_dtype),
int(out_dtype),
int(bias_dtype),
True, # transa
False, # transb
False, # grad
workspace.shape[0],
accumulate,
use_split_accumulator,
0, # math_sm_count
)
return out, gelu_input
def cast_to_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
"""Cast input to FP8"""
if out is None:
out = paddle.empty(
shape=inp.shape,
dtype=paddle.uint8,
)
else:
assert out.shape == inp.shape, "Output shape does not match input shape."
assert out.dtype == paddle.uint8, "Output should be of uint8 dtype."
tex.cast_to_fp8(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return out
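# Illustration only (editor's sketch, not part of the original file): the delayed-scaling
# bookkeeping cast_to_fp8 relies on, written in plain paddle under the assumption that the
# kernel records max|inp| in row 0 of amax_history and keeps scale_inv = 1 / scale for the slot.
import paddle

def _update_fp8_meta_ref(inp, scale, scale_inv, amax_history, idx):
    amax_history[0, idx] = paddle.max(paddle.abs(inp)).astype("float32")
    scale_inv[idx] = 1.0 / scale[idx]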
def cast_from_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
itype: tex.DType,
otype: tex.DType,
) -> paddle.Tensor:
"""Cast input from FP8"""
return tex.cast_from_fp8(
inp,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(itype),
int(otype),
)
def transpose(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Transpose input"""
return tex.te_transpose(
inp,
int(otype),
)
def cast_transpose(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
cast_out: Optional[paddle.Tensor] = None,
transpose_out: Optional[paddle.Tensor] = None,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]:
"""Cast + Transpose with FP8 output"""
if cast_out is None:
cast_out = paddle.empty(
shape=inp.shape,
dtype=paddle.uint8,
)
else:
assert cast_out.shape == inp.shape, "cast_out shape does not match input shape."
assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype."
if transpose_out is None:
transpose_out = paddle.empty(
shape=[inp.shape[1], inp.shape[0]],
dtype=paddle.uint8,
)
else:
assert transpose_out.shape == [
inp.shape[1],
inp.shape[0],
], "Transposed output shape does not match input shape."
assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype."
tex.te_cast_transpose(
inp,
fp8_meta_tensor.scale,
cast_out,
transpose_out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return cast_out, transpose_out
def cast_transpose_bgrad(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]:
"""Fused Cast + Transpose + Bias Grad"""
grad_bias, cast_out, transpose_out, _, _ = tex.te_cast_transpose_bgrad(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return grad_bias, cast_out, transpose_out
def te_gelu(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Non FP8 GELU"""
return tex.te_gelu(
inp,
int(otype),
)
def gelu_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
) -> paddle.Tensor:
"""GELU + FP8 cast"""
out, _, _ = tex.te_gelu_fp8(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return out
def swiglu(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Non FP8 SWIGLU"""
return tex.te_swiglu(
inp,
int(otype),
)
def swiglu_pd(
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Native SWIGLU"""
gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1)
out = F.silu(gate_out) * up_out
return out
def swiglu_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
) -> paddle.Tensor:
"""SWIGLU + FP8 cast"""
out, _, _ = tex.te_swiglu_fp8(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return out
def dswiglu(
grad_output: paddle.Tensor,
swiglu_input: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""dSWIGLU"""
return tex.te_dswiglu(
grad_output,
swiglu_input,
int(otype),
)
def dgelu_cast_transpose_bgrad_fp8(
grad_output: paddle.Tensor,
gelu_input: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""
Fused dgelu + cast / transpose / reduce the result of
the GELU backward along the first dimension
"""
cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return cast_dgelu, transpose_dgelu, dbias
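# Illustration only (editor's sketch, not part of the original file): a rough eager-mode
# reference for the fusion above, ignoring the FP8 casts and any difference between the exact
# and tanh-approximated GELU derivative used by the kernel.
import paddle
import paddle.nn.functional as F

def _dgelu_bgrad_ref(grad_output, gelu_input):
    x = gelu_input.detach()
    x.stop_gradient = False
    (dx,) = paddle.grad(F.gelu(x), x, grad_outputs=grad_output)
    dbias = dx.sum(axis=0)  # reduce along the first dimension
    return dx, paddle.transpose(dx, perm=[1, 0]), dbias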
def layernorm_fwd_fp8(
inp: paddle.Tensor,
weight: paddle.Tensor,
bias: paddle.Tensor,
eps: float,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""LayerNorm with FP8 output"""
out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8(
inp,
weight,
bias,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
eps,
fp8_tensor.value,
int(otype),
sm_margin,
zero_centered_gamma,
)
return out, mu, rsigma
def layernorm_fwd(
inp: paddle.Tensor,
weight: paddle.Tensor,
bias: paddle.Tensor,
eps: float,
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm forward"""
return tex.te_layernorm_fwd(inp, weight, bias, eps, int(otype), sm_margin, zero_centered_gamma)
def layernorm_bwd(
dz: paddle.Tensor,
x: paddle.Tensor,
mu: paddle.Tensor,
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 LayerNorm backward"""
return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma)
def rmsnorm_fwd(
inp: paddle.Tensor,
weight: paddle.Tensor,
eps: float,
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm forward"""
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma)
def rmsnorm_fwd_fp8(
inp: paddle.Tensor,
weight: paddle.Tensor,
eps: float,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(
inp,
weight,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
eps,
fp8_tensor.value,
int(otype),
sm_margin,
zero_centered_gamma,
)
return out, rsigma
def rmsnorm_bwd(
dz: paddle.Tensor,
x: paddle.Tensor,
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm backward"""
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)
def mask_to_cu_seqlens(
mask: paddle.Tensor,
need_kv: bool = False,
) -> paddle.Tensor:
"""Convert mask to cu_seqlens"""
# mask shape: [b, 1, s_q, s_kv]
if get_global_fp8_state().is_cudagraph_enabled():
raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.")
q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]
q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
q_cu_seqlens[0] = 0
kv_cu_seqlens = None
if need_kv:
kv_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
kv_cu_seqlens[0] = 0
tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv)
return q_cu_seqlens, kv_cu_seqlens
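# Illustration only (editor's sketch, not part of the original file): the cumulative
# sequence-length layout the kernel fills in, assuming per-batch valid lengths of [5, 3, 7].
import paddle

seqlens = paddle.to_tensor([5, 3, 7], dtype="int32")
cu_seqlens = paddle.concat([paddle.zeros([1], dtype="int32"), paddle.cumsum(seqlens, axis=0)])
# cu_seqlens -> [0, 5, 8, 15]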
def fused_attn_fwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
is_training: bool,
max_seqlen: int,
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bs3hd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert Bias.shape == [
1,
h,
max_seqlen,
max_seqlen,
], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv."
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1
) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
else:
out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype)
if is_training:
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32")
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(
shape=[
2,
],
dtype=paddle.int64,
)
# execute kernel
tex.te_fused_attn_fwd_qkvpacked(
qkv,
cu_seqlens,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
total_seqs,
max_seqlen,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
rng_elts_per_thread,
)
return out, softmax_aux, rng_state
def fused_attn_bwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bs3hd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed QKV input"""
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else:
dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype)
if bias_type != "no_bias":
if qkv_format == "thd":
dbias = paddle.zeros(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype)
else:
dbias = None
# execute kernel
dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked(
qkv,
cu_seqlens,
o,
d_o,
softmax_aux,
dqkv,
dbias,
rng_state,
b,
h,
d,
total_seqs,
max_seqlen,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
deterministic,
)
return dqkv, dbias
def fused_attn_fwd_kvpacked(
q: paddle.Tensor,
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bs2hd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
d = q.shape[3]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert Bias.shape == [
1,
h,
max_seqlen_q,
max_seqlen_kv,
], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv."
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1
) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
if is_training:
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32")
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(
shape=[
2,
],
dtype=paddle.int64,
)
# execute kernel
tex.te_fused_attn_fwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
total_seqs_q,
total_seqs_kv,
max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
rng_elts_per_thread,
)
return out, softmax_aux, rng_state
def fused_attn_bwd_kvpacked(
q: paddle.Tensor,
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bs2hd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input"""
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
d = q.shape[3]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
else:
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype)
if bias_type != "no_bias":
if qkv_format == "thd":
dbias = paddle.zeros(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
tex.te_fused_attn_bwd_kvpacked(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
o,
d_o,
softmax_aux,
dq,
dkv,
dbias,
rng_state,
b,
h,
d,
total_seqs_q,
total_seqs_kv,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
deterministic,
)
return dq, dkv, dbias
def fused_attn_fwd(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for unpacked QKV input"""
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (
qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert Bias.shape == [
1,
h,
max_seqlen_q,
max_seqlen_kv,
], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv."
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
rng_elts_per_thread = None
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (
max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1
) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
if is_training:
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32")
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(
shape=[
2,
],
dtype=paddle.int64,
)
# execute kernel
tex.te_fused_attn_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
rng_elts_per_thread,
)
return out, softmax_aux, rng_state
def fused_attn_bwd(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
deterministic: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input"""
assert qkv_dtype in (
tex.DType.kBFloat16,
tex.DType.kFloat16,
), "Only support bf16/fp16 for fused attention."
assert (
cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (
qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (
fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "thd":
set_zero = True
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype)
else:
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dk = paddle.empty(shape=k.shape, dtype=k.dtype)
dv = paddle.empty(shape=v.shape, dtype=v.dtype)
if bias_type != "no_bias":
if qkv_format == "thd":
dbias = paddle.zeros(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
tex.te_fused_attn_bwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
o,
d_o,
softmax_aux,
dq,
dk,
dv,
dbias,
rng_state,
b,
h,
d,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
deterministic,
)
return dq, dk, dv, dbias
def scaled_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
"""scaled softmax forward"""
return tex.te_scaled_softmax_forward(inp, scale_factor)
def scaled_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
"""scaled softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
def scaled_masked_softmax_forward(
inp: paddle.Tensor,
mask: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
"""scaled masked softmax forward"""
return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor)
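# Illustration only (editor's sketch, not part of the original file): a rough eager-mode
# equivalent of the fused scaled-masked-softmax forward, assuming `mask` is boolean with True
# marking positions to drop and that it broadcasts against the scaled input.
import paddle
import paddle.nn.functional as F

def _scaled_masked_softmax_ref(inp, mask, scale_factor):
    x = inp * scale_factor
    x = paddle.where(mask, paddle.full_like(x, float("-inf")), x)
    return F.softmax(x, axis=-1)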
def scaled_masked_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
"""scaled masked softmax backward"""
tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
def scaled_upper_triang_masked_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
"""scaled upper triang masked softmax forward"""
return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor)
def scaled_upper_triang_masked_softmax_backward(
out_grad: paddle.Tensor,
softmax_results: paddle.Tensor,
scale_factor: float,
) -> paddle.Tensor:
"""scaled upper triang masked softmax backward"""
tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor)
return out_grad
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common.h"
namespace transformer_engine {
namespace paddle_ext {
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
const DType type) {
return TensorWrapper(const_cast<void *>(data_ptr), shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
return TensorWrapper(data_ptr, shape, type);
}
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) {
return TensorWrapper(data_ptr, shape, type, reinterpret_cast<float *>(amax_ptr),
reinterpret_cast<float *>(scale_ptr),
reinterpret_cast<float *>(scale_inv_ptr));
}
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) { // NOLINT
return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype()));
}
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) {
return MakeNvteTensor(const_cast<void *>(tensor.data()), GetShapeArray(tensor),
Paddle2NvteDType(tensor.dtype()));
}
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
bool init_to_zeros) {
auto size = shape.ndim;
if (size == 2 && init_to_zeros) {
return paddle::zeros({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
Nvte2PaddleDType(type), place);
} else if (size == 2) {
return paddle::empty({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
Nvte2PaddleDType(type), place);
} else if (size == 1 && init_to_zeros) {
return paddle::zeros({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
} else if (size == 1) {
return paddle::empty({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
}
NVTE_CHECK(false, "Should never reach here! func: AllocateSpace");
}
// MHA utils
// convert QKV layout to enum
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) {
static const std::unordered_map<std::string, NVTE_QKV_Layout> layout_map = {
{"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD},
{"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D},
{"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD},
{"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D},
{"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD},
{"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD},
{"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D},
{"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD},
{"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D},
{"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD},
{"t3hd", NVTE_QKV_Layout::NVTE_T3HD},
{"th3d", NVTE_QKV_Layout::NVTE_TH3D},
{"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD},
{"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D},
{"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD},
};
auto it = layout_map.find(qkv_layout);
if (it != layout_map.end()) {
return it->second;
} else {
NVTE_ERROR("Invalid QKV layout string: " + qkv_layout);
}
}
} // namespace paddle_ext
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <cublasLt.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <cstdlib>
#include <vector>
#include "common/util/logging.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/all_context.h"
namespace transformer_engine {
namespace paddle_ext {
// Paddle Tensor Utils
template <typename T>
inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) {
if (index < 0 || index >= x.numel()) {
NVTE_ERROR("Index out of bound");
}
return reinterpret_cast<const void *>(x.data<T>() + static_cast<size_t>(index));
}
template <typename T>
inline void *GetDataPtr(paddle::Tensor &x, int64_t index) { // NOLINT
if (index < 0 || index >= x.numel()) {
NVTE_ERROR("Index out of bound");
}
return reinterpret_cast<void *>(x.data<T>() + static_cast<size_t>(index));
}
template <typename T>
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x, int64_t index) {
return x ? GetDataPtr<T>(*x, index) : nullptr;
}
template <typename T>
inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x, int64_t index) { // NOLINT
return x ? GetDataPtr<T>(*x, index) : nullptr;
}
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x) {
return x ? x->data() : nullptr;
}
inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) { // NOLINT
return x ? x->data() : nullptr;
}
inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
std::vector<size_t> shapes;
for (auto dim : x.shape()) {
shapes.push_back(static_cast<size_t>(dim));
}
return shapes;
}
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
if (x) return GetShapeArray(x.get());
return {0};
}
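// Illustrative sketch (not part of the API): the helpers above are typically
// combined with MakeNvteTensor (declared below) to wire a Paddle tensor plus
// per-index amax/scale/scale_inv entries into an NVTE FP8 tensor. The tensor
// names (out, amax, scale, scale_inv) and the dtype here are hypothetical.
//
//   auto out_cu = MakeNvteTensor(
//       out.data(), GetShapeArray(out), DType::kFloat8E4M3,
//       GetDataPtr<float>(amax, index),
//       const_cast<void *>(GetDataPtr<float>(scale, index)),
//       GetDataPtr<float>(scale_inv, index));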
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
                             bool init_to_zeros = false);
// DType Utils
inline paddle::DataType Nvte2PaddleDType(DType t) {
switch (t) {
case DType::kInt32:
case DType::kFloat32:
return paddle::DataType::FLOAT32;
case DType::kFloat16:
return paddle::DataType::FLOAT16;
case DType::kBFloat16:
return paddle::DataType::BFLOAT16;
case DType::kByte:
case DType::kFloat8E4M3:
case DType::kFloat8E5M2:
return paddle::DataType::UINT8;
default:
NVTE_ERROR("Invalid type");
}
}
inline DType Paddle2NvteDType(paddle::DataType t) {
switch (t) {
case paddle::DataType::FLOAT16:
return DType::kFloat16;
case paddle::DataType::FLOAT32:
return DType::kFloat32;
case paddle::DataType::BFLOAT16:
return DType::kBFloat16;
case paddle::DataType::BOOL:
return DType::kByte;
case paddle::DataType::UINT8:
return DType::kByte;
case paddle::DataType::INT32:
return DType::kInt32;
case paddle::DataType::INT64:
return DType::kInt64;
default:
NVTE_ERROR("Invalid type");
}
}
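// Note: the dtype mappings above are storage-oriented rather than strictly
// semantic: FP8 (E4M3/E5M2) and byte data travel as UINT8 on the Paddle side,
// BOOL and UINT8 both collapse to DType::kByte, and kInt32 is paired with the
// same-size FLOAT32 when allocating Paddle buffers.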
inline DType Int2NvteDType(int64_t dtype) {
if (dtype >= 0 && dtype < static_cast<int64_t>(DType::kNumTypes)) {
return static_cast<DType>(dtype);
} else {
NVTE_ERROR("Type not supported.");
}
}
// get the fused attention backend
inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim, head_dim, -1, -1);
return fused_attention_backend;
}
// CUDA Utils
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
static thread_local cudaDevicePropertiesManager instance;
return instance;
}
int GetMultiProcessorCount() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop_, device_id));
prop_queried_ = true;
}
return prop_.multiProcessorCount;
}
int GetMajor() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop_, device_id));
prop_queried_ = true;
}
return prop_.major;
}
private:
bool prop_queried_ = false;
cudaDeviceProp prop_;
};
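// Illustrative usage (variable names follow the wrapper functions below): the SM
// count is queried once per thread and an optional caller-provided margin is
// subtracted before being handed to the normalization kernels.
//
//   const int num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
//   const int sms_for_norm = num_sm - sm_margin;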
// NVTE Tensor Utils
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor);
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout);
} // namespace paddle_ext
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cub/cub.cuh>
#include <map>
#include <vector>
#include "common.h"
#include "common/common.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
namespace transformer_engine {
namespace paddle_ext {
// convert bias type to enum
NVTE_Bias_Type get_nvte_bias_type(const std::string &bias_type) {
if (bias_type == "no_bias") {
return NVTE_Bias_Type::NVTE_NO_BIAS;
} else if (bias_type == "pre_scale_bias") {
return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS;
} else if (bias_type == "post_scale_bias") {
return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
} else {
NVTE_ERROR("Invalid bias type. \n");
}
}
// convert attn mask type to enum
NVTE_Mask_Type get_nvte_mask_type(const std::string &mask_type) {
if (mask_type == "padding") {
return NVTE_Mask_Type::NVTE_PADDING_MASK;
} else if (mask_type == "causal") {
return NVTE_Mask_Type::NVTE_CAUSAL_MASK;
} else if (mask_type == "no_mask") {
return NVTE_Mask_Type::NVTE_NO_MASK;
} else {
NVTE_ERROR("Invalid attention mask type. \n");
}
}
void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &output, // NOLINT
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), shape, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream());
}
std::vector<paddle::Tensor> cast_from_fp8(const paddle::Tensor &input,
const paddle::Tensor &scale_inv, int64_t index,
int64_t itype, int64_t otype) {
auto shape = GetShapeArray(input);
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)));
auto input_cu =
MakeNvteTensor(const_cast<void *>(input.data()), shape, Int2NvteDType(itype), nullptr,
nullptr, const_cast<void *>(GetDataPtr<float>(scale_inv, index)));
auto output_cu = MakeNvteTensor(output);
nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_transpose(const paddle::Tensor &input, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place());
auto input_cu = MakeNvteTensor(const_cast<void *>(input.data()), {M, N}, Int2NvteDType(otype));
auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype));
nvte_transpose(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &output_cast, // NOLINT
paddle::Tensor &output_transpose, // NOLINT
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto input_cu = MakeNvteTensor(input);
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
input.stream());
}
std::vector<paddle::Tensor> te_cast_transpose_bgrad(const paddle::Tensor &grad_output,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(grad_output);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto grad_bias =
paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place());
auto grad_output_cast =
paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place());
auto grad_output_transpose =
paddle::empty({grad_output.shape()[1], grad_output.shape()[0]},
Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place());
auto input_cu = MakeNvteTensor(grad_output);
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
auto output_transpose_cu =
MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto dbias_cu = MakeNvteTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
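  // NVTE kernels follow a two-pass convention: the first call, made with an
  // empty workspace wrapper, only reports the required workspace shape/dtype;
  // after allocating that buffer, the second call performs the computation.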
nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
dbias_cu.data(), workspace.data(), grad_output.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
dbias_cu.data(), workspace.data(), grad_output.stream());
return {grad_bias, grad_output_cast, grad_output_transpose};
}
void te_gemm(const paddle::Tensor &A, const paddle::optional<paddle::Tensor> &A_scale_inverse,
const paddle::Tensor &B, const paddle::optional<paddle::Tensor> &B_scale_inverse,
const paddle::optional<paddle::Tensor> &bias, paddle::Tensor &D, // NOLINT
paddle::optional<paddle::Tensor> &D_scale, // NOLINT
paddle::optional<paddle::Tensor> &D_amax, // NOLINT
paddle::optional<paddle::Tensor> &pre_gelu_out, paddle::Tensor &workspace, // NOLINT
int64_t A_index, int64_t B_index, int64_t D_index, int64_t A_type, int64_t B_type,
int64_t D_type, int64_t bias_type, bool transa, bool transb, bool grad,
int64_t workspace_size, bool accumulate, bool use_split_accumulator,
int64_t math_sm_count) {
auto te_A = MakeNvteTensor(
const_cast<void *>(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr,
const_cast<void *>(GetOptionalDataPtr<float>(A_scale_inverse, A_index)));
auto te_B = MakeNvteTensor(
const_cast<void *>(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr,
const_cast<void *>(GetOptionalDataPtr<float>(B_scale_inverse, B_index)));
auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type),
GetOptionalDataPtr<float>(D_amax, D_index),
GetOptionalDataPtr<float>(D_scale, D_index), nullptr);
auto te_bias = MakeNvteTensor(const_cast<void *>(GetOptionalDataPtr(bias)), GetShapeArray(bias),
Int2NvteDType(bias_type));
DType gelu_dtype = pre_gelu_out ? Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type);
auto te_pre_gelu_out =
MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype);
auto te_workspace =
MakeNvteTensor(workspace.data(), {static_cast<size_t>(workspace_size)}, DType::kByte);
nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(),
transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator,
math_sm_count, A.stream());
}
std::vector<paddle::Tensor> te_gelu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_gelu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_gelu(const paddle::Tensor &input, int64_t otype) {
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype));
nvte_gelu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_swiglu(const paddle::Tensor &input, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype));
nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input,
int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype()));
auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype()));
auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype()));
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output,
const paddle::Tensor &gelu_input,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(grad_output);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto grad_bias =
paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place());
auto dgelu = paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place());
auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]},
Nvte2PaddleDType(DType::kByte), grad_output.place());
void *amax_data = GetDataPtr<float>(amax, index);
void *scale_data = const_cast<void *>(GetDataPtr<float>(scale, index));
void *scale_inv_data = GetDataPtr<float>(scale_inv, index);
TensorWrapper workspace;
auto gelu_input_cu = MakeNvteTensor(gelu_input);
auto input_cu = MakeNvteTensor(grad_output);
auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data,
scale_data, scale_inv_data);
auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype),
amax_data, scale_data, scale_inv_data);
auto dbias_cu = MakeNvteTensor(grad_bias);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), workspace.data(),
grad_output.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), workspace.data(),
grad_output.stream());
return {dgelu, dgelu_transpose, grad_bias};
}
std::vector<paddle::Tensor> te_layernorm_fwd_fp8(const paddle::Tensor &input,
const paddle::Tensor &weight,
const paddle::Tensor &bias,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto beta_cu = MakeNvteTensor(bias);
auto z_cu = MakeNvteTensor(
ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace tensor with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
return {ln_out, mu, rsigma};
}
std::vector<paddle::Tensor> te_layernorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight,
const paddle::Tensor &bias, float eps, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto mu = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto beta_cu = MakeNvteTensor(bias);
auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace tensor with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin,
zero_centered_gamma, input.stream());
return {ln_out, mu, rsigma};
}
std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
const paddle::Tensor &mu, const paddle::Tensor &rsigma,
const paddle::Tensor &gamma, int64_t sm_margin,
bool zero_centered_gamma) {
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
TensorWrapper workspace;
auto dz_cu = MakeNvteTensor(dz);
auto x_cu = MakeNvteTensor(x);
auto mu_cu = MakeNvteTensor(mu);
auto rsigma_cu = MakeNvteTensor(rsigma);
auto gamma_cu = MakeNvteTensor(gamma);
auto dx_cu = MakeNvteTensor(dx);
auto dgamma_cu = MakeNvteTensor(dgamma);
auto dbeta_cu = MakeNvteTensor(dbeta);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates tensors with the required config.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
num_sm - sm_margin, zero_centered_gamma, dz.stream());
// Alloc space for Tensors.
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(),
num_sm - sm_margin, zero_centered_gamma, dz.stream());
return {dx, dgamma, dbeta};
}
std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight, float eps, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype));
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace tensor with the required config
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// Actual call to fwd kernel
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
return {ln_out, rsigma};
}
std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
const paddle::Tensor &weight,
const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto rsigma = paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
auto gamma_cu = MakeNvteTensor(weight);
auto z_cu = MakeNvteTensor(
ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
auto rsigma_cu = MakeNvteTensor(rsigma);
TensorWrapper workspace;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates workspace tensor with the required config
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
// Fill workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// Actual call to fwd kernel
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream());
return {ln_out, rsigma};
}
std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
const paddle::Tensor &rsigma,
const paddle::Tensor &gamma, int64_t sm_margin,
bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm.");
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
TensorWrapper workspace;
auto dz_cu = MakeNvteTensor(dz);
auto x_cu = MakeNvteTensor(x);
auto rsigma_cu = MakeNvteTensor(rsigma);
auto gamma_cu = MakeNvteTensor(gamma);
auto dx_cu = MakeNvteTensor(dx);
auto dgamma_cu = MakeNvteTensor(dgamma);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// This call populates tensors with the required config.
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma,
dz.stream());
// Alloc space for Tensors.
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(),
dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma,
dz.stream());
return {dx, dgamma};
}
__global__ void set_rng_state(
    [[maybe_unused]] unsigned int
        identifier,  // Relates this kernel launch to CUDA graph nodes; see
                     // https://github.com/PaddlePaddle/Paddle/pull/60516
std::pair<uint64_t, uint64_t> seed_offset, int64_t *rng_state_ptr) {
rng_state_ptr[0] = static_cast<int64_t>(seed_offset.first);
rng_state_ptr[1] = static_cast<int64_t>(seed_offset.second);
}
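// The seed/offset pair is written by a device kernel (rather than a host copy) so
// that the update can be captured as a CUDA graph node; under graph capture the
// parameterSetter path below refreshes the pair on every replay.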
void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_per_thread,
paddle::Tensor &rng_state) {
// extract random number generator seed and offset
const phi::DeviceContext *dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
phi::Generator *gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
int64_t *rng_state_p = static_cast<int64_t *>(rng_state.data());
#if PADDLE_VERSION > 261
auto state_index = gen_cuda->GetStateIndex();
auto parameterSetter = [gen_cuda, state_index,
rng_elts_per_thread](phi::backends::gpu::gpuKernelParams &params) {
    // ensure the generator uses the correct state index
gen_cuda->SetStateIndex(state_index);
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
};
phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p);
return cudaFunc;
};
phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
cudaKernelCallback);
#else
set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p);
#endif
}
void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs,
int64_t max_seqlen, bool is_training, float attn_scale,
float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type,
const int64_t qkv_type, int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_QKV = MakeNvteTensor(QKV);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
te_O = MakeNvteTensor(O);
} else { // TODO: support fp8
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && Bias) {
auto bias_shape = Bias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
}
te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
UpdateRandomGenerator(QKV.place(), QKV.stream(), rng_elts_per_thread, rng_state);
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
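  // The first nvte_fused_attn_fwd_qkvpacked call below is a shape/dtype query: it
  // fills the aux tensor pack (softmax statistics, rng state) and the workspace
  // wrapper without running the attention kernel; the second call does the work.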
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// fused attention BWD with packed QKV
void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
const paddle::Tensor &O, const paddle::Tensor &dO,
const paddle::Tensor &softmax_aux,
paddle::Tensor &dQKV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs,
int64_t max_seqlen, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type,
const std::string &attn_mask_type, int64_t qkv_type,
bool deterministic) {
TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_QKV = MakeNvteTensor(QKV);
te_O = MakeNvteTensor(O);
te_dO = MakeNvteTensor(dO);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dQKV = MakeNvteTensor(dQKV);
} else {
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen), static_cast<size_t>(max_seqlen)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens;
te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
deterministic, workspace.data(), QKV.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
deterministic, workspace.data(), QKV.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_fwd_kvpacked(
const paddle::Tensor &Q, const paddle::Tensor &KV, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cu_seqlens_kv, const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, int64_t total_seqs_kv,
int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, float attn_scale,
float p_dropout, const std::string &qkv_layout, const std::string &bias_type,
const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(
Q.data(),
{static_cast<size_t>(total_seqs_q), static_cast<size_t>(h), static_cast<size_t>(d)},
qkv_dtype);
te_KV = MakeNvteTensor(
KV.data(),
{static_cast<size_t>(total_seqs_kv), 2, static_cast<size_t>(h), static_cast<size_t>(d)},
qkv_dtype);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
te_O = MakeNvteTensor(
O.data(),
{static_cast<size_t>(total_seqs_q), static_cast<size_t>(h), static_cast<size_t>(d)},
qkv_dtype);
} else {
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && Bias) {
auto bias_shape = Bias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
}
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
UpdateRandomGenerator(Q.place(), Q.stream(), rng_elts_per_thread, rng_state);
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// fused attention BWD with packed KV
void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cu_seqlens_kv, const paddle::Tensor &O,
const paddle::Tensor &dO, const paddle::Tensor &softmax_aux,
paddle::Tensor &dQ, // NOLINT
paddle::Tensor &dKV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs_q,
int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type,
int64_t qkv_type, bool deterministic) {
TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_KV = MakeNvteTensor(KV);
te_O = MakeNvteTensor(O);
te_dO = MakeNvteTensor(dO);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dQ = MakeNvteTensor(dQ);
te_dKV = MakeNvteTensor(dKV);
} else {
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q), static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, deterministic, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_kvpacked(
te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
-1, -1, deterministic, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
bool is_training, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type,
const std::string &attn_mask_type, const int64_t qkv_type,
int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_K = MakeNvteTensor(K);
te_V = MakeNvteTensor(V);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
te_O = MakeNvteTensor(O);
} else { // TODO: support fp8
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && Bias) {
auto bias_shape = Bias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
}
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
auto stream = Q.stream();
auto rng_state_p = static_cast<int64_t *>(rng_state.data());
#if PADDLE_VERSION > 261
auto state_index = gen_cuda->GetStateIndex();
auto parameterSetter = [gen_cuda, state_index,
rng_elts_per_thread](phi::backends::gpu::gpuKernelParams &params) {
    // ensure the generator uses the correct state index
gen_cuda->SetStateIndex(state_index);
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
params.As<std::pair<int64_t, int64_t>>(1) = seed_offset;
};
phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&set_rng_state);
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p);
return cudaFunc;
};
phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
cudaKernelCallback);
#else
set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p);
#endif
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s = reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
const paddle::Tensor &O, const paddle::Tensor &dO,
const paddle::Tensor &softmax_aux,
paddle::Tensor &dQ, // NOLINT
paddle::Tensor &dK, // NOLINT
paddle::Tensor &dV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type,
int64_t qkv_type, bool deterministic) {
TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_K = MakeNvteTensor(K);
te_V = MakeNvteTensor(V);
te_O = MakeNvteTensor(O);
te_dO = MakeNvteTensor(dO);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dQ = MakeNvteTensor(dQ);
te_dK = MakeNvteTensor(dK);
te_dV = MakeNvteTensor(dV);
} else {
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q), static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// create workspace
TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
std::vector<paddle::Tensor> te_scaled_softmax_forward(const paddle::Tensor &input,
float scale_factor) {
NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK(
(input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
const int batches = input.shape()[0];
const int attn_heads = input.shape()[1];
const int query_seq_len = input.shape()[2];
const int key_seq_len = input.shape()[3];
NVTE_CHECK(key_seq_len <= 4096);
NVTE_CHECK(query_seq_len > 1);
// Output
auto softmax_results = paddle::empty_like(input, input.dtype(), input.place());
auto input_cu = MakeNvteTensor(input);
auto softmax_results_cu = MakeNvteTensor(softmax_results);
nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
input.stream());
return {softmax_results};
}
void te_scaled_softmax_backward(paddle::Tensor &output_grads, // NOLINT
const paddle::Tensor &softmax_results, float scale_factor) {
NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) ||
(output_grads.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) ||
(softmax_results.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = MakeNvteTensor(output_grads);
auto softmax_results_cu = MakeNvteTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(),
output_grads_cu.data(), scale_factor, softmax_results.stream());
}
std::vector<paddle::Tensor> te_scaled_masked_softmax_forward(const paddle::Tensor &input,
const paddle::Tensor &mask,
float scale_factor) {
NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK(
(input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
const int batches = input.shape()[0];
const int pad_batches = mask.shape()[0];
const int attn_heads = input.shape()[1];
const int query_seq_len = input.shape()[2];
const int key_seq_len = input.shape()[3];
NVTE_CHECK(key_seq_len <= 4096);
NVTE_CHECK(query_seq_len > 1);
NVTE_CHECK(pad_batches == 1 || pad_batches == batches);
NVTE_CHECK(mask.shape()[1] == 1);
NVTE_CHECK(mask.shape()[2] == query_seq_len);
NVTE_CHECK(mask.shape()[3] == key_seq_len);
// Output
auto softmax_results = paddle::empty_like(input, input.dtype(), input.place());
auto input_cu = MakeNvteTensor(input);
auto mask_cu = MakeNvteTensor(mask);
auto softmax_results_cu = MakeNvteTensor(softmax_results);
nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
scale_factor, input.stream());
return {softmax_results};
}
void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT
const paddle::Tensor &softmax_results, float scale_factor) {
NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor");
NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) ||
(output_grads.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) ||
(softmax_results.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = MakeNvteTensor(output_grads);
auto softmax_results_cu = MakeNvteTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(),
output_grads_cu.data(), scale_factor, softmax_results.stream());
}
std::vector<paddle::Tensor> te_scaled_upper_triang_masked_softmax_forward(
const paddle::Tensor &input, float scale_factor) {
NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor");
NVTE_CHECK(
(input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
const int attn_batches = input.shape()[0];
const int seq_len = input.shape()[1];
NVTE_CHECK(seq_len <= 2048);
// Output
auto softmax_results = paddle::empty_like(input, input.dtype(), input.place());
auto input_cu = MakeNvteTensor(input);
auto softmax_results_cu = MakeNvteTensor(softmax_results);
nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(),
scale_factor, input.stream());
return {softmax_results};
}
void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT
const paddle::Tensor &softmax_results,
float scale_factor) {
NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor");
NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor");
NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) ||
(output_grads.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) ||
(softmax_results.dtype() == paddle::DataType::BFLOAT16),
"Only fp16 and bf16 are supported");
NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]);
auto output_grads_cu = MakeNvteTensor(output_grads);
auto softmax_results_cu = MakeNvteTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_upper_triang_masked_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor,
softmax_results.stream());
}
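// The kernel below implements the FP8 delayed-scaling update: the amax history is
// shifted by one step (rolled_amax_history), the newest slot is cleared, and each
// scale is recomputed from the reduced amax. A minimal sketch of the per-tensor
// update it performs (illustration only; the kernel additionally keeps the previous
// scale when amax is zero or non-finite):
//   scale     = (fp8_max / amax) / 2^margin
//   scale_inv = 1 / scale   (only for non-weight tensors unless
//                            update_weight_scale_inv is set)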
__global__ void UpdateFP8MetaKernel(
    [[maybe_unused]] unsigned int
        identifier,  // Used to relate the kernel to CUDA graph nodes; see
                     // https://github.com/PaddlePaddle/Paddle/pull/60516
const float *amax, const float *rolled_amax_history, const bool *non_weight_mask,
float *amax_history, float *scale, float *scale_inv, bool update_weight_scale_inv, float margin,
float fp8_max, size_t history_numel, size_t amax_numel) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= history_numel) {
return;
}
amax_history[idx] = rolled_amax_history[idx];
if (idx < amax_numel) {
float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
scale[idx] = scale_reg;
if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg;
amax_history[idx] = 0.0f;
}
}
constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
const paddle::Tensor &non_weight_mask, int64_t fp8_dtype,
float margin, const std::string &amax_compute) {
auto amax_history_ = MakeNvteTensor(amax_history);
auto scale_ = MakeNvteTensor(scale);
auto scale_inv_ = MakeNvteTensor(scale_inv);
const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask);
nvte_delayed_scaling_recipe_amax_and_scale_update(
amax_history_.data(), scale_.data(), scale_inv_.data(), non_weight_mask_.data(),
amax_history_.data(), scale_.data(), scale_inv_.data(), amax_compute.c_str(),
static_cast<NVTEDType>(fp8_dtype), margin, amax_history.stream());
}
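// The legacy path below launches UpdateFP8MetaKernel through Paddle's
// CUDAGraphNodeLauncher so the launch can be captured into a CUDA graph. The
// parameterSetter lambda patches kernel argument 7 (update_weight_scale_inv) from
// *current_step_id_ptr when the graph node parameters are set, which is why the
// kernel carries the extra `identifier` argument (see the PaddlePaddle PR
// referenced above).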
void amax_and_scale_update_inplace_legacy(
paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
const paddle::Tensor &non_weight_mask,
const paddle::optional<paddle::Tensor> &current_step_id_tensor, bool update_weight_scale_inv,
bool fwd_update, float fp8_max, float margin, const std::string &amax_compute) {
#if PADDLE_VERSION > 261
NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
paddle::Tensor amax;
if (amax_compute == "max") {
amax = amax_history.max({0});
} else {
amax = amax_history.slice(0, 1);
}
const auto rolled_amax_history = amax_history.roll({-1}, {0});
auto amax_history_numel = amax_history.numel();
auto amax_numel = amax.numel();
size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int *current_step_id_ptr =
reinterpret_cast<const int *>(GetOptionalDataPtr(current_step_id_tensor));
auto parameterSetter = [current_step_id_ptr,
fwd_update](phi::backends::gpu::gpuKernelParams &params) {
if (fwd_update) {
int current_step_id = *current_step_id_ptr;
params.As<bool>(7) = (current_step_id == 0);
}
};
const float *amax_ptr = amax.data<float>();
const float *rolled_amax_history_ptr = rolled_amax_history.data<float>();
const bool *non_weight_mask_ptr = non_weight_mask.data<bool>();
float *amax_history_ptr = amax_history.data<float>();
float *scale_ptr = scale.data<float>();
float *scale_inv_ptr = scale_inv.data<float>();
phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback =
[=](unsigned int id) {
void *functionPtr = reinterpret_cast<void *>(&UpdateFP8MetaKernel);
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr));
UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
id, amax_ptr, rolled_amax_history_ptr, non_weight_mask_ptr, amax_history_ptr, scale_ptr,
scale_inv_ptr, update_weight_scale_inv, margin, fp8_max, amax_history_numel,
amax_numel);
NVTE_CHECK_CUDA(cudaGetLastError());
return cudaFunc;
};
phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
cudaKernelCallback);
#else
  NVTE_ERROR(
      "amax_and_scale_update_inplace_legacy is not supported in older versions of PaddlePaddle\n");
#endif
}
void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT
const paddle::Tensor &amax) {
// Copy amax to history[0]
NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), amax.numel() * SizeOf(amax.dtype()),
cudaMemcpyDeviceToDevice, amax.stream()));
}
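// The two kernels below convert a boolean attention mask (true = masked/padded
// position) into cumulative sequence lengths (cu_seqlens) for fused attention. One
// block per batch element counts the unmasked rows/columns and stores the count at
// index b + 1; a single-block prefix sum then turns those counts into offsets, with
// element 0 expected to hold the leading zero (assumed to be pre-initialized by the
// caller).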
__global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel(
const bool *mask, int32_t *q_actual_seqlen, int32_t *kv_actual_seqlen, int q_seqlen,
int kv_seqlen, bool need_kv) {
typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage q_smem;
__shared__ typename BlockReduce::TempStorage kv_smem;
unsigned int tid = threadIdx.x;
unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen;
// load mask, convert to 1/0, do accumulation
int q = 0, kv = 0;
for (unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen;
q_idx += BLOCK_SIZE * kv_seqlen) {
q += (mask[q_idx + batch_offset] ? 0 : 1);
}
if (need_kv) {
for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) {
kv += (mask[kv_idx + batch_offset] ? 0 : 1);
}
}
__syncthreads();
// compute cub::BlockReduce
int q_sum, kv_sum;
q_sum = BlockReduce(q_smem).Sum(q);
if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv);
// write result for this block to global mem
if (tid == 0) {
q_actual_seqlen[blockIdx.x + 1] = q_sum;
if (need_kv) {
kv_actual_seqlen[blockIdx.x + 1] = kv_sum;
}
}
}
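// Single-block in-place inclusive scan over x[1..n-1]; x[0] is left untouched so the
// result is a valid cu_seqlens array when x[0] is zero. At most BLOCK_SIZE (512)
// elements are scanned, which the caller checks before launching.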
__global__ __launch_bounds__(BLOCK_SIZE) void block_prefix_sum_inplace(int32_t *x, int n) {
typedef cub::BlockScan<int32_t, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage smem;
// +1 to ignore the first element
int i = blockIdx.x * blockDim.x + threadIdx.x + 1;
// load data
int32_t thread_data[1];
thread_data[0] = i < n ? x[i] : 0;
__syncthreads();
// CUB block prefix sum
BlockScan(smem).InclusiveSum(thread_data, thread_data);
__syncthreads();
// write result
if (i < n) {
x[i] = thread_data[0];
}
}
void mask_to_cu_seqlens(const paddle::Tensor &mask,
paddle::Tensor &q_cu_seqlen, // NOLINT
paddle::optional<paddle::Tensor> &kv_cu_seqlen, // NOLINT
int q_seqlen, int kv_seqlen, bool need_kv) {
if (need_kv) {
NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr,
"kv_cu_seqlen must be provided when need_kv is true");
}
mask_to_actual_seqlens_kernel<<<mask.shape()[0], BLOCK_SIZE, 0, mask.stream()>>>(
mask.data<bool>(), q_cu_seqlen.data<int32_t>(),
reinterpret_cast<int32_t *>(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, need_kv);
  // q_cu_seqlen has shape [bs+1]; assume bs is small enough (<= BLOCK_SIZE, i.e. 512)
  // that a single block can perform the prefix sum
  NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE,
             "batch size too large for single-block prefix sum");
block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data<int32_t>(),
q_cu_seqlen.numel());
if (need_kv) {
block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(
reinterpret_cast<int32_t *>(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel());
}
}
} // namespace paddle_ext
} // namespace transformer_engine
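// Custom-operator registrations for Paddle. The convention throughout: inputs whose
// names start with "_" are buffers modified in place and are exposed as outputs of
// the same name without the prefix via SetInplaceMap; Attrs are declared as
// "name: type" strings; SetKernelFn binds the C++ entry point through PD_KERNEL.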
PD_BUILD_OP(te_gemm)
.Inputs({"A", paddle::Optional("A_scale_inverse"), "B", paddle::Optional("B_scale_inverse"),
paddle::Optional("bias"), "_D", paddle::Optional("_D_scale"),
paddle::Optional("_D_amax"), paddle::Optional("_pre_gelu_out"), "_workspace"})
.Outputs({"D", paddle::Optional("D_scale"), paddle::Optional("D_amax"),
paddle::Optional("pre_gelu_out"), "workspace"})
.Attrs({"A_index: int64_t", "B_index: int64_t", "D_index: int64_t", "A_type: int64_t",
"B_type: int64_t", "D_type: int64_t", "bias_type: int64_t", "transa: bool",
"transb: bool", "grad: bool", "workspace_size: int64_t", "accumulate: bool",
"use_split_accumulator: bool", "math_sm_count: int64_t"})
.SetInplaceMap({{"_D", "D"},
{paddle::Optional("_D_scale"), paddle::Optional("D_scale")},
{paddle::Optional("_D_amax"), paddle::Optional("D_amax")},
{paddle::Optional("_pre_gelu_out"), paddle::Optional("pre_gelu_out")},
{"_workspace", "workspace"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm));
PD_BUILD_OP(cast_to_fp8)
.Inputs({"Input", "Scale", "_Output", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetInplaceMap({{"_Output", "Output"}, {"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8));
PD_BUILD_OP(cast_from_fp8)
.Inputs({"Input", "ScaleInv"})
.Outputs({"Output"})
.Attrs({"index: int64_t", "itype: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_from_fp8));
PD_BUILD_OP(te_transpose)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose));
PD_BUILD_OP(te_cast_transpose)
.Inputs({"Input", "Scale", "_CastedOutput", "_TransposedOutput", "_Amax", "_ScaleInv"})
.Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"})
.SetInplaceMap({{"_CastedOutput", "CastedOutput"},
{"_TransposedOutput", "TransposedOutput"},
{"_Amax", "Amax"},
{"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose));
PD_BUILD_OP(te_cast_transpose_bgrad)
.Inputs({"GradOutput", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"dBias", "CastedOutput", "TransposedOutput", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad));
PD_BUILD_OP(te_gelu_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu_fp8));
PD_BUILD_OP(te_gelu)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu));
PD_BUILD_OP(te_swiglu)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu));
PD_BUILD_OP(te_swiglu_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8));
PD_BUILD_OP(te_dswiglu)
.Inputs({"Grad", "Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu));
PD_BUILD_OP(te_cast_transpose_bgrad_dgelu)
.Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad_dgelu));
PD_BUILD_OP(te_layernorm_fwd_fp8)
.Inputs({"Input", "Weight", "Bias", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Mu", "Rsigma", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t",
"zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd_fp8));
PD_BUILD_OP(te_layernorm_fwd)
.Inputs({"Input", "Weight", "Bias"})
.Outputs({"Output", "Mu", "Rsigma"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd));
PD_BUILD_OP(te_layernorm_bwd)
.Inputs({"Dz", "X", "Mu", "Rsigma", "Gamma"})
.Outputs({"Dx", "Dgamma", "Dbeta"})
.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_bwd));
PD_BUILD_OP(te_rmsnorm_fwd)
.Inputs({"Input", "Weight"})
.Outputs({"Output", "InvVariance"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd));
PD_BUILD_OP(te_rmsnorm_fwd_fp8)
.Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "InvVariance", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t",
"zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8));
PD_BUILD_OP(te_rmsnorm_bwd)
.Inputs({"Dz", "X", "Rsigma", "Gamma"})
.Outputs({"Dx", "Dgamma"})
.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd));
PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
.Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"),
"_rng_state"})
.Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
"is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
"bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
"rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
{"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));
PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
.Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias"),
"rng_state"})
.Outputs({"dQKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
"attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
"bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
"deterministic: bool"})
.SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked));
PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
.Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
paddle::Optional("_softmax_aux"), "_rng_state"})
.Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
"total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
"is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
"bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
"rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
{"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked));
PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
.Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV",
paddle::Optional("_dBias"), "rng_state"})
.Outputs({"dQ", "dKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
"total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
"attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
"bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
"deterministic: bool"})
.SetInplaceMap({{"_dQ", "dQ"},
{"_dKV", "dKV"},
{paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked));
PD_BUILD_OP(te_fused_attn_fwd)
.Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
paddle::Optional("_softmax_aux"), "_rng_state"})
.Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t", "rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
{"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd));
PD_BUILD_OP(te_fused_attn_bwd)
.Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK",
"_dV", paddle::Optional("_dBias"), "rng_state"})
.Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t", "deterministic: bool"})
.SetInplaceMap({{"_dQ", "dQ"},
{"_dK", "dK"},
{"_dV", "dV"},
{paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd));
PD_BUILD_OP(te_scaled_softmax_forward)
.Inputs({"input"})
.Outputs({"softmax_results"})
.Attrs({"scale_factor: float"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_forward));
PD_BUILD_OP(te_scaled_softmax_backward)
.Inputs({"out_grad_", "softmax_results"})
.Outputs({"out_grad"})
.Attrs({"scale_factor: float"})
.SetInplaceMap({{"out_grad_", "out_grad"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_backward));
PD_BUILD_OP(te_scaled_masked_softmax_forward)
.Inputs({"input", "mask"})
.Outputs({"softmax_results"})
.Attrs({"scale_factor: float"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_forward));
PD_BUILD_OP(te_scaled_masked_softmax_backward)
.Inputs({"out_grad_", "softmax_results"})
.Outputs({"out_grad"})
.Attrs({"scale_factor: float"})
.SetInplaceMap({{"out_grad_", "out_grad"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_backward));
PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_forward)
.Inputs({"input"})
.Outputs({"softmax_results"})
.Attrs({"scale_factor: float"})
.SetKernelFn(
PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_forward));
PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward)
.Inputs({"out_grad_", "softmax_results"})
.Outputs({"out_grad"})
.Attrs({"scale_factor: float"})
.SetInplaceMap({{"out_grad_", "out_grad"}})
.SetKernelFn(
PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward));
PD_BUILD_OP(amax_and_scale_update_inplace_legacy)
.Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask",
paddle::Optional("current_step_id_tensor")})
.Outputs({"amax_history", "scale", "scale_inv"})
.SetInplaceMap({{"_amax_history", "amax_history"},
{"_scale", "scale"},
{"_scale_inv", "scale_inv"}})
.Attrs({"update_weight_scale_inv: bool", "fwd_update: bool", "fp8_max: float", "margin: float",
"amax_compute: std::string"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace_legacy));
PD_BUILD_OP(amax_and_scale_update_inplace)
.Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"})
.Outputs({"amax_history", "scale", "scale_inv"})
.SetInplaceMap({{"_amax_history", "amax_history"},
{"_scale", "scale"},
{"_scale_inv", "scale_inv"}})
.Attrs({"fp8_dtype: int64_t", "margin: float", "amax_compute: std::string"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace));
PD_BUILD_OP(update_latest_amax_history_inplace)
.Inputs({"_history", "amax"})
.Outputs({"history"})
.SetInplaceMap({{"_history", "history"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace));
PD_BUILD_OP(mask_to_cu_seqlens)
.Inputs({"mask", "_q_cu_seqlen", paddle::Optional("_kv_cu_seqlen")})
.Outputs({"q_cu_seqlen", paddle::Optional("kv_cu_seqlen")})
.Attrs({"q_seqlen: int", "kv_seqlen: int", "need_kv: bool"})
.SetInplaceMap({{"_q_cu_seqlen", "q_cu_seqlen"},
{paddle::Optional("_kv_cu_seqlen"), paddle::Optional("kv_cu_seqlen")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::mask_to_cu_seqlens));
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common.h"
namespace transformer_engine {
namespace paddle_ext {
size_t get_cublasLt_version() { return cublasLtGetVersion(); }
PYBIND11_MODULE(transformer_engine_paddle, m) {
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string");
// Data structures
py::enum_<DType>(m, "DType", py::module_local())
.value("kByte", DType::kByte)
.value("kInt32", DType::kInt32)
.value("kFloat32", DType::kFloat32)
.value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2);
py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", py::module_local())
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
}
} // namespace paddle_ext
} // namespace transformer_engine
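// Rough usage sketch from Python, assuming the built extension is importable under
// the module name declared above (exact packaging/import path may differ):
//   import transformer_engine_paddle as tex
//   print(tex.get_cublasLt_version())
//   dtype = tex.DType.kFloat8E4M3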