Unverified commit 989a53a0 authored by cyanguwa and committed by GitHub

Add FP8 fused attention (#155)



* Add FP8 fused attention to TE for PyTorch
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add C API docs for fused attention
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add exception for unsupported precision/sequence length combinations
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix installation requirement for non-fused-attn use cases
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix docs for fused-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes based on PR comments
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description for kvpacked fwd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description of Bias in C API
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes for cuDNN requirement and description of QKV tensors
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix QKV layout description and support matrix for C API
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add asserts to cpp_extensions for QKV layout/bias type/attn mask type
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix "precision" typo
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------

Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c3407300
@@ -17,6 +17,8 @@ jobs:
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: |
          mkdir -p wheelhouse && \
@@ -41,6 +43,8 @@ jobs:
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: |
          pip install ninja pybind11 && \
@@ -66,6 +70,8 @@ jobs:
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: |
          pip install ninja pybind11 && \
......
[submodule "3rdparty/googletest"] [submodule "3rdparty/googletest"]
path = 3rdparty/googletest path = 3rdparty/googletest
url = https://github.com/google/googletest.git url = https://github.com/google/googletest.git
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609
@@ -138,3 +138,25 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================
cudnn-frontend
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
..
    Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

fused_attn.h
============

.. doxygenfile:: fused_attn.h
@@ -17,6 +17,7 @@ directly from C/C++, without Python.
activation.h <activation>
cast.h <cast>
gemm.h <gemm>
fused_attn.h <fused_attn>
layer_norm.h <layer_norm>
softmax.h <softmax>
transformer_engine.h <transformer_engine>
......
@@ -14,6 +14,8 @@ Prerequisites
1. Linux x86_64
2. `CUDA 11.8 <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 11.8 or later.
4. `cuDNN 8 <https://developer.nvidia.com/cudnn>`__ or later.
5. For FP8 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 <https://developer.nvidia.com/cudnn>`__ or later.

Transformer Engine in NGC Containers
......
@@ -105,6 +105,7 @@ framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")
include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
    "3rdparty/cudnn-frontend/include",
]
if NVTE_WITH_USERBUFFERS:
    if MPI_HOME:
......
@@ -42,6 +42,7 @@ const std::string &typeName(DType type) {
static const std::unordered_map<DType, std::string> name_map = {
{DType::kByte, "byte"},
{DType::kInt32, "int32"},
{DType::kInt64, "int64"},
{DType::kFloat32, "float32"},
{DType::kFloat16, "float16"},
{DType::kBFloat16, "bfloat16"},
......
@@ -44,6 +44,7 @@ struct BytesToType<8> {
using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
@@ -54,6 +55,7 @@ template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int32,
int64,
fp32,
fp16,
bf16,
@@ -211,6 +213,12 @@ bool isFp8Type(DType type);
{__VA_ARGS__} \
} \
break; \
case DType::kInt64: \
{ \
using type = int64; \
{__VA_ARGS__} \
} \
break; \
case DType::kFloat32: \
{ \
using type = float; \
......
@@ -19,7 +19,9 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()

list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/")

find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(CUDNN REQUIRED cudnn)
find_package(Python COMPONENTS Interpreter Development REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
......
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
find_path(
CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include
)
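# Search order for the header (inferred from the hints above): the CUDNN_PATH
# environment variable, then a CUDNN_PATH CMake variable, then the CUDA toolkit's
# own include directories. The library lookup below follows the same order.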
function(find_cudnn_library NAME)
string(TOUPPER ${NAME} UPPERCASE_NAME)
find_library(
${UPPERCASE_NAME}_LIBRARY ${NAME}
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib
)
if(${UPPERCASE_NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties(
CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${UPPERCASE_NAME}_LIBRARY}
)
message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.")
else()
message(STATUS "${NAME} not found.")
endif()
endfunction()
find_cudnn_library(cudnn)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)
include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
CUDNN REQUIRED_VARS
CUDNN_INCLUDE_DIR CUDNN_LIBRARY
)
if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
message(STATUS "cuDNN: ${CUDNN_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
else()
set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")
endif()
target_include_directories(
CUDNN::cudnn_all
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
CUDNN::cudnn
)
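
# Minimal usage sketch (mirroring the top-level CMakeLists change below;
# "my_target" is a hypothetical target name):
#   list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/")
#   find_package(CUDNN REQUIRED cudnn)
#   target_link_libraries(my_target PUBLIC CUDNN::cudnn)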
@@ -12,6 +12,9 @@ list(APPEND transformer_engine_SOURCES
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
@@ -30,9 +33,11 @@ target_include_directories(transformer_engine PUBLIC
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDA::nvToolsExt
CUDNN::cudnn)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include")
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp8.h"
// NVTE fused attention FWD FP8 with packed QKV
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_fwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[2];
size_t d = input_QKV->data.shape[3];
const DType QKV_type = input_QKV->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8900)
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
// FP8 API doesn't use input_Bias, bias_type or attn_mask_type
fused_attn_fwd_fp8_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O,
Aux_Output_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
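// Illustrative sketch (not part of this file): invoking the FP8 forward above
// under its support matrix (FP8 dtype, max_seqlen <= 512). Tensor creation is
// omitted; assume each NVTETensor was built with the TE C API. The helper
// nvte_tensor_pack_create and the enumerator spellings NVTE_NO_BIAS /
// NVTE_PADDING_MASK are taken from this PR's headers as an assumption.
//
//   NVTETensorPack aux;
//   nvte_tensor_pack_create(&aux);
//   nvte_fused_attn_fwd_qkvpacked(
//       QKV, Bias, S, O, &aux, cu_seqlens, rng_state,
//       /*max_seqlen=*/512, /*is_training=*/true,
//       /*attn_scale=*/0.125f, /*dropout=*/0.0f,
//       NVTE_QKV_INTERLEAVED, NVTE_NO_BIAS, NVTE_PADDING_MASK,
//       workspace, stream);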
// NVTE fused attention BWD FP8 with packed QKV
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQKV,
const NVTETensor cu_seqlens,
size_t max_seqlen,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_bwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQKV = reinterpret_cast<Tensor*>(dQKV);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[2];
size_t d = input_QKV->data.shape[3];
const DType QKV_type = input_QKV->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8900)
// Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
// FP8 API doesn't use input_dBias, bias_type or attn_mask_type
fused_attn_bwd_fp8_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
output_dQKV,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
// NVTE fused attention FWD FP8 with packed KV
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[1];
size_t d = input_Q->data.shape[2];
const DType QKV_type = input_Q->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
// NVTE fused attention BWD FP8 with packed KV
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dKV,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQ = reinterpret_cast<Tensor*>(dQ);
Tensor *output_dKV = reinterpret_cast<Tensor*>(dKV);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[1];
size_t d = input_Q->data.shape[2];
const DType QKV_type = input_Q->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp8.h"
namespace transformer_engine {
namespace fused_attn {
using namespace transformer_engine;
#if (CUDNN_VERSION >= 8900)
std::unordered_map<std::string, int> tensor_name_to_uid = {
{"Q", 1},
{"K", 2},
{"V", 3},
{"O", 4},
{"S", 5},
{"B", 6},
{"DROPOUT_SCALE", 7},
{"S_CONST", 8},
{"MNK_OVERRIDE", 9},
{"dQ", 11},
{"dK", 12},
{"dV", 13},
{"dO", 14},
{"MASK_VAL", 15},
{"dS", 16},
{"O_SEQLEN", 17},
{"M", 18},
{"Z", 19},
{"descaleQ", 20},
{"descaleK", 21},
{"descaleV", 22},
{"descaleS", 23},
{"scaleS", 24},
{"amaxS", 25},
{"amaxO", 26},
{"QKV_RAGGED", 27},
{"O_RAGGED", 28},
{"K_TRANSPOSE", 29},
{"AttnScale", 30},
{"scaleO", 31},
{"Z_INV", 32},
{"descaleO", 33},
{"descaledO", 34},
{"descaledS", 35},
{"descaledQ", 36},
{"descaledK", 37},
{"descaledV", 38},
{"scaledS", 39},
{"scaledQ", 40},
{"scaledK", 41},
{"scaledV", 42},
{"amaxdS", 43},
{"amaxdQ", 44},
{"amaxdK", 45},
{"amaxdV", 46},
{"V_TRANSPOSE", 47},
{"AttnScale_dS_K", 48},
{"AttnScale_dSTranspose_Q", 49},
{"DROPOUT_SCALE_dOVt_OdO", 50},
{"DROPOUT_OFFSET", 51},
{"DROPOUT_SEED", 52},
{"VIRTUAL", 80}
};
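// Note on the UID scheme above: every named graph tensor gets a small fixed UID,
// and "VIRTUAL" (80) acts as a base offset for intermediate (virtual) tensor UIDs
// (e.g. VIRTUAL + 151/152/153 in the softmax chain, VIRTUAL + 5000/7000 in the
// scale helpers), keeping them disjoint from the named tensors' IDs.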
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
(void)engine_config;
return false;
}
static cudnn_frontend::Tensor tensor_create(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.build();
return tensor_created;
}
static cudnn_frontend::Tensor tensor_create_with_offset(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value,
std::shared_ptr<cudnn_frontend::Tensor> raggedOffset) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.setRaggedOffset(raggedOffset)
.build();
return tensor_created;
}
static cudnn_frontend::PointWiseDesc pw_desc_create(
cudnnDataType_t type, cudnnPointwiseMode_t mode) {
auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder()
.setMode(mode)
.setComputeType(type)
.build();
return pw_desc_created;
}
static cudnn_frontend::Operation unary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
static cudnn_frontend::Operation binary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
static cudnn_frontend::Operation ternary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &tDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.settDesc(tDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
static cudnn_frontend::Tensor createAmax(
const std::string& amax_tensor_name,
const cudnn_frontend::Tensor& prevBlockOutputTensor,
std::vector<cudnn_frontend::Operation>* ops) {
int64_t amax_dim[4] = {1, 1, 1, 1};
int64_t amax_stride[4] = {1, 1, 1, 1};
auto amaxTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name],
amax_dim, amax_stride, false, false);
// Define the amax descriptor
auto reductionDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_AMAX)
.build();
// Create a reduction amax Node
auto reduction_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(prevBlockOutputTensor)
.setyDesc(amaxTensor)
.setreductionDesc(reductionDesc)
.build();
ops->push_back(std::move(reduction_op));
return amaxTensor;
}
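// The amax (absolute-max reduction into a single-element [1,1,1,1] tensor above)
// is the statistic FP8 scaling recipes consume to derive the next iteration's
// scale factors.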
static cudnn_frontend::Tensor createScale(
const cudnn_frontend::Tensor& prevBlockOutputTensor,
const std::string& scale_tensor_name,
cudnnDataType_t tensorType,
bool isOutputVirtual, bool isScaleByValue,
std::vector<cudnn_frontend::Operation>* ops,
const std::string& output_tensor_name ="") {
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t output_dim[4];
int64_t output_stride[4];
for (int i = 0; i < 4; i++) {
output_dim[i] = prevBlockOutputTensor.getDim()[i];
output_stride[i] = prevBlockOutputTensor.getStride()[i];
}
auto scaleTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name],
scale_dim, scale_stride, false, isScaleByValue); // is by value
int64_t outputUID = isOutputVirtual ? tensor_name_to_uid["VIRTUAL"]
+ tensor_name_to_uid[scale_tensor_name] + 5000 :
tensor_name_to_uid[output_tensor_name];
auto afterScaleKTensor = tensor_create(
tensorType, outputUID, output_dim,
output_stride, isOutputVirtual, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a Scale Node
auto scale_op = binary_pw_op_create(
prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return afterScaleKTensor;
}
static cudnn_frontend::Tensor createScale(
const cudnn_frontend::Tensor& prevBlockOutputTensor,
const cudnn_frontend::Tensor& scaleTensor,
cudnnDataType_t tensorType,
bool isOutputVirtual, bool isScaleByValue,
std::vector<cudnn_frontend::Operation>* ops,
int UID_offset, const std::string& output_tensor_name ="") {
int64_t output_dim[4];
int64_t output_stride[4];
for (int i = 0; i < 4; i++) {
output_dim[i] = prevBlockOutputTensor.getDim()[i];
output_stride[i] = prevBlockOutputTensor.getStride()[i];
}
int64_t outputUID = isOutputVirtual ?
tensor_name_to_uid["VIRTUAL"] + UID_offset :
tensor_name_to_uid[output_tensor_name];
auto afterScaleTensor = tensor_create(
tensorType, outputUID, output_dim,
output_stride, isOutputVirtual, false); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a Scale Node
auto scale_op = binary_pw_op_create(
prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return afterScaleTensor;
}
static cudnn_frontend::Tensor createScaleWithOffset(
const cudnn_frontend::Tensor& prevBlockOutputTensor,
const std::string& scale_tensor_name,
cudnnDataType_t tensorType,
bool isOutputVirtual,
bool isScaleByValue,
std::vector<cudnn_frontend::Operation>* ops,
std::shared_ptr<cudnn_frontend::Tensor> offsetTensor,
const std::string& output_tensor_name ="") {
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
int64_t output_dim[4];
int64_t output_stride[4];
// If output tensor is dQ, dK, or dV, we need to generate QKV interleaved strides
if (output_tensor_name == "dQ" || output_tensor_name == "dK" || output_tensor_name == "dV") {
for (int i = 0; i < 4; i++) {
output_dim[i] = prevBlockOutputTensor.getDim()[i];
}
generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2],
0 /*s_kv = 0 for placeholder*/,
output_dim[3], output_stride,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, NVTE_QKV_Matrix::NVTE_Q_Matrix);
} else {
// Otherwise output dim and stride should be the same as prev block dim and stride
for (int i = 0; i < 4; i++) {
output_dim[i] = prevBlockOutputTensor.getDim()[i];
output_stride[i] = prevBlockOutputTensor.getStride()[i];
}
}
auto scaleTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name],
scale_dim, scale_stride, false, isScaleByValue); // is by value
cudnnDataType_t outputDataType = isOutputVirtual ? CUDNN_DATA_FLOAT : tensorType;
int64_t outputUID = isOutputVirtual ?
tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 :
tensor_name_to_uid[output_tensor_name];
auto afterScaleTensor = tensor_create_with_offset(
outputDataType, outputUID, output_dim,
output_stride, isOutputVirtual, false, offsetTensor); // is virtual
// Define the scale descriptor
auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a Scale Node
auto scale_op = binary_pw_op_create(
prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc);
ops->push_back(std::move(scale_op));
return afterScaleTensor;
}
static cudnn_frontend::Tensor createSoftmaxForward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& prevBlockOutputTensor,
bool isTraining) {
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
// max (x) (M tensor)
auto afterMaxReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["M"],
afterReduction_dim, afterReduction_stride,
!isTraining, false); // not virtual if training is true,
// virtual if training is false
// x - max(x)
auto afterSubtractionTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// e^(x - max(x))
auto afterExponentTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// sum (e^(x - max(x))) (Z tensor)
auto zTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"],
afterReduction_dim, afterReduction_stride, true, false); // is virtual
// 1 / sum (e^(x - max(x))) (Z_INV tensor)
auto zInvTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"],
afterReduction_dim, afterReduction_stride,
!isTraining, false); // not virtual if training is true,
// virtual if training is false
// Final softmax output (After exponent * Z_INV)
auto beforeDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 153,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// Define the reduction descriptor
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.build();
// Create a reduction max Node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(prevBlockOutputTensor)
.setyDesc(afterMaxReductionTensor)
.setreductionDesc(reductionMaxDesc)
.build();
// Define the subtract descriptor
auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtract Node
auto subtract_op = binary_pw_op_create(
prevBlockOutputTensor, afterMaxReductionTensor,
afterSubtractionTensor, subtractDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create an exponent Node
auto exponent_op = unary_pw_op_create(
afterSubtractionTensor, afterExponentTensor, exponentDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add Node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(afterExponentTensor)
.setyDesc(zTensor)
.setreductionDesc(reductionAddDesc)
.build();
// Define the reciprocal descriptor
auto reciprocalDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL);
// Create a reciprocal Node
auto reciprocal_op = unary_pw_op_create(zTensor, zInvTensor, reciprocalDesc);
// Define the pw multiply descriptor
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply Node
auto multiply_op = binary_pw_op_create(
afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc);
ops->push_back(std::move(reductionMax_op));
ops->push_back(std::move(subtract_op));
ops->push_back(std::move(exponent_op));
ops->push_back(std::move(reductionAdd_op));
ops->push_back(std::move(reciprocal_op));
ops->push_back(std::move(multiply_op));
return beforeDropoutTensor;
}
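// In math terms, the chain above is the numerically stable softmax
//   M = max_j x_j,   Z = sum_j exp(x_j - M),   S_j = exp(x_j - M) * Z^{-1},
// with M and Z_INV materialized (non-virtual) in training so the backward pass
// can recompute S without re-running the reductions.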
static cudnn_frontend::Tensor createDropoutForward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
double probability,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& beforeDropoutTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"Dropout DAG constructed incorrectly: dropout cannot be the first op",
CUDNN_STATUS_BAD_PARAM);
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
// Mask for the dropout
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
auto dropoutSeedTensor = tensor_create(
CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"],
scale_dim, scale_stride, false, false); // is by value
auto dropoutOffsetTensor = tensor_create(
CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"],
scale_dim, scale_stride, false, false); // is by value
// After-dropout tensor, before scale
auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(tensor_name_to_uid["VIRTUAL"] + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t::
CUDNN_TENSOR_REORDERING_F16x16)
.build();
// Scale after dropout
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"],
scale_dim, scale_stride, false, true); // is by value
// After Scale
auto afterDropout_before_quan_S = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// Define the reduction descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng Node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeedTensor)
.setOffsetDesc(dropoutOffsetTensor)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask Node
auto maskMul_op = binary_pw_op_create(
beforeDropoutTensor, dropoutMaskTensor,
beforeDropoutScaleTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale Node
auto scaleMul_op = binary_pw_op_create(
beforeDropoutScaleTensor, scaleDropoutTensor,
afterDropout_before_quan_S, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterDropout_before_quan_S;
}
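// Dropout semantics of the chain above: the RNG node draws a Bernoulli(1 - p)
// keep-mask, and DROPOUT_SCALE (1 / (1 - p), per the backward's comment) rescales
// the survivors so the expected value matches the pre-dropout tensor.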
static cudnn_frontend::Tensor createDropoutBackward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
double probability,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& beforeDropoutTensor,
const cudnn_frontend::Tensor& dropoutMaskTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"Dropout DAG constructed incorrectly: dropout cannot be the first op",
CUDNN_STATUS_BAD_PARAM);
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
auto dropoutSeedTensor = tensor_create(
CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"],
scale_dim, scale_stride, false, false); // is by value
auto dropoutOffsetTensor = tensor_create(
CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"],
scale_dim, scale_stride, false, false); // is by value
// After-dropout tensor, before scale
auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder()
.setDim(4, afterBMM1_dim)
.setStride(4, afterBMM1_stride)
.setId(tensor_name_to_uid["VIRTUAL"] + 201)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual(true)
.setByValue(false)
.setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t::
CUDNN_TENSOR_REORDERING_F16x16)
.build();
// Scale after dropout (1 / (1 - p))
auto scaleDropoutTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"],
scale_dim, scale_stride, false, true); // is by value
// After Scale
auto afterDropout_before_quan_S = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// Define the reduction descriptor
auto rngDesc = cudnn_frontend::RngDescBuilder()
.setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI)
.setBernoulliDistProbability(1.0 - probability)
.build();
// Create a rng Node
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeedDesc(dropoutSeedTensor)
.setOffsetDesc(dropoutOffsetTensor)
.setRngDesc(rngDesc)
.build();
// Define the multiply mask descriptor
auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask Node
auto maskMul_op = binary_pw_op_create(
beforeDropoutTensor, dropoutMaskTensor,
beforeDropoutScaleTensor, maskMulDesc);
// Define the multiply scale descriptor
auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply scale Node
auto scaleMul_op = binary_pw_op_create(
beforeDropoutScaleTensor, scaleDropoutTensor,
afterDropout_before_quan_S, scaleMulDesc);
ops->push_back(std::move(rng_op));
ops->push_back(std::move(maskMul_op));
ops->push_back(std::move(scaleMul_op));
return afterDropout_before_quan_S;
}
static cudnn_frontend::Tensor createSoftmaxBackward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& dyTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"Softmax backward constructed incorrectly: it cannot be the first op",
CUDNN_STATUS_BAD_PARAM);
int64_t dx_dim[4] = {b, h, s_q, s_kv};
int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
int64_t M_Z_dim[4] = {b, h, s_q, 1};
int64_t M_Z_stride[4] = {h * s_q, s_q, 1, 1};
// Creating all tensors
auto MTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["M"],
M_Z_dim, M_Z_stride, false, false); // not virtual
auto ZInvTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"],
M_Z_dim, M_Z_stride, false, false); // not virtual
auto dxAfterSubtractionTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252,
dx_dim, dx_stride, true, false); // is virtual
auto dxAfterExponentiation = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253,
dx_dim, dx_stride, true, false); // is virtual
auto dxBeforeDropout_QKt_Tensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254,
dx_dim, dx_stride, true, false); // is virtual
// Creating all ops
// sub (dy - M)
auto subtractionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
auto subtractionOp = binary_pw_op_create(
dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc);
// Define the exponent descriptor
auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP);
// Create an exponent Node. (exp(dy - M))
auto exponentOp = unary_pw_op_create(
dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc);
// Define the pw multiply descriptor
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply Node
auto multiplyOp = binary_pw_op_create(
dxAfterExponentiation, ZInvTensor, dxBeforeDropout_QKt_Tensor, multiplyDesc);
ops->push_back(std::move(subtractionOp));
ops->push_back(std::move(exponentOp));
ops->push_back(std::move(multiplyOp));
return dxBeforeDropout_QKt_Tensor;
}
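// Despite the "dy" name, this chain recomputes the forward probabilities
// S = exp(x - M) * Z_INV from the saved M and Z_INV statistics; the actual
// gradient algebra happens in createBiasSubtractionSoftmaxMulChain below.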
static cudnn_frontend::Tensor createQKBMM(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &qTensor,
const cudnn_frontend::Tensor &kTensor,
const cudnn_frontend::Tensor &mnkOverride,
std::shared_ptr<cudnn_frontend::Tensor> QKVRaggedOffsetTensor) {
// Creates the necessary tensor descriptors
int64_t k_transpose_dim[4] = {b, h, d, s_kv};
int64_t k_transpose_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d,
k_transpose_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose);
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto kTransposeTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["K_TRANSPOSE"],
k_transpose_dim, k_transpose_stride,
false, false, QKVRaggedOffsetTensor); // not virtual
// First GEMM output
auto afterQKTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1,
s_dim, s_stride, true, false); // is virtual
// Define the matmul desc
auto matmulDesc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(-2000000)
.build();
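// Reading of the padding value above: -2000000 stands in for -inf on padded
// positions of the ragged batch, so the downstream softmax drives their
// probabilities toward zero (the backward matmuls pad with 0 instead).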
// Create reshape node for K -> K.T
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(kTensor)
.setyDesc(kTransposeTensor)
.build();
// Create a matmul Node
auto matmulOp = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(qTensor)
.setbMatDesc(kTransposeTensor)
.setcMatDesc(afterQKTensor)
.setmOverrideDesc(mnkOverride)
.setnOverrideDesc(mnkOverride)
.setmatmulDesc(matmulDesc)
.build();
ops->push_back(std::move(reshape_op));
ops->push_back(std::move(matmulOp));
return afterQKTensor;
}
static cudnn_frontend::Tensor createSVBMM(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &softmaxTensor,
const cudnn_frontend::Tensor &mnkOverride,
std::shared_ptr<cudnn_frontend::Tensor> QKVRaggedOffsetTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"BMM2 op constructed incorrectly: it cannot be the first op",
CUDNN_STATUS_BAD_PARAM);
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
auto vTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["V"],
v_dim, v_stride, false, false, QKVRaggedOffsetTensor);
// Second fprop GEMM output
auto oTensor = tensor_create(
tensorType, tensor_name_to_uid["VIRTUAL"] + 300,
o_dim, o_stride, true, false); // is virtual
// Define the matmul desc
auto matmulDesc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.build();
// Create a matmul Node
auto matmulOp = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(softmaxTensor)
.setbMatDesc(vTensor)
.setcMatDesc(oTensor)
.setmOverrideDesc(mnkOverride)
.setkOverrideDesc(mnkOverride)
.setmatmulDesc(matmulDesc)
.build();
ops->push_back(std::move(matmulOp));
return oTensor;
}
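// Note: the MNK_OVERRIDE tensor (per-batch actual sequence lengths) is what lets
// these matmuls run on ragged variable-length batches: BMM1 overrides m and n,
// while this BMM2 overrides m and k (the contraction over s_kv).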
static cudnn_frontend::Tensor createSdOBMM(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &softmaxTensor,
const cudnn_frontend::Tensor &dOTensor,
const cudnn_frontend::Tensor &mnkOverride) {
cudnn_frontend::throw_if(ops->size() == 0,
"BMM2 op constructed incorrectly: it cannot be the first op",
CUDNN_STATUS_BAD_PARAM);
int64_t s_dim_transpose[4] = {b, h, s_kv, s_q};
int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv};
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4] = {h * s_kv * d, d, h * d, 1};
auto sTransposeTensor = tensor_create(
tensorType, tensor_name_to_uid["VIRTUAL"] + 499,
s_dim_transpose, s_stride_transpose,
true, false); // is virtual
// S.T * dO
auto dVTensor_before_dequan_S = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500,
v_dim, v_stride,
true, false); // is virtual
// Create reshape node for softmax -> softmax.T
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(softmaxTensor)
.setyDesc(sTransposeTensor)
.build();
// Define the matmul desc
auto matmulDesc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0)
.build();
// Create a matmul Node
auto matmulOp = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(sTransposeTensor)
.setbMatDesc(dOTensor)
.setcMatDesc(dVTensor_before_dequan_S)
.setmOverrideDesc(mnkOverride)
.setkOverrideDesc(mnkOverride)
.setmatmulDesc(matmulDesc)
.build();
ops->push_back(std::move(reshape_op));
ops->push_back(std::move(matmulOp));
return dVTensor_before_dequan_S;
}
static cudnn_frontend::Tensor createdOVBMM(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &dOTensor,
const cudnn_frontend::Tensor &mnkOverride,
std::shared_ptr<cudnn_frontend::Tensor> QKVRaggedOffsetTensor) {
// Creates the necessary tensor descriptors
int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
int64_t v_transpose_dim[4] = {b, h, d, s_kv};
int64_t v_transpose_stride[4];
v_transpose_stride[0] = v_stride[0];
v_transpose_stride[1] = v_stride[1];
v_transpose_stride[2] = v_stride[3];
v_transpose_stride[3] = v_stride[2];
int64_t s_dim[4] = {b, h, s_q, s_kv};
int64_t s_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto vTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["V"],
v_dim, v_stride,
false, false, QKVRaggedOffsetTensor);
auto vTransposeTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["V_TRANSPOSE"],
v_transpose_dim, v_transpose_stride,
false, false, QKVRaggedOffsetTensor); // not virtual
// dO * V.T
auto afterdOVTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600,
s_dim, s_stride, true, false); // is virtual
// Define the matmul desc
auto matmulDesc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0)
.build();
// Create reshape node for V -> V.T
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(vTensor)
.setyDesc(vTransposeTensor)
.build();
// Create a matmul Node
auto matmulOp = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dOTensor)
.setbMatDesc(vTransposeTensor)
.setcMatDesc(afterdOVTensor)
.setmOverrideDesc(mnkOverride)
.setnOverrideDesc(mnkOverride)
.setmatmulDesc(matmulDesc)
.build();
ops->push_back(std::move(reshape_op));
ops->push_back(std::move(matmulOp));
return afterdOVTensor;
}
static cudnn_frontend::Tensor createdOAndORowReductionChain(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &O_after_dequan,
const cudnn_frontend::Tensor &dO_after_dequan,
const cudnn_frontend::Tensor &dropoutScale_dOVt_OdO_Tensor) {
int64_t o_dim[4] = {b, h, s_q, d};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t o_dim_row_sum[4] = {b, h, s_q, 1};
int64_t o_dim_row_sum_stride[4] = {s_q * h, s_q, 1, 1};
auto O_dO_after_pointwise_multiply = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700,
o_dim, o_stride, true, false); // is virtual
auto O_dO_after_dropout_scale = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701,
o_dim, o_stride, true, false); // is virtual
auto O_dO_after_rowsum = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 702,
o_dim_row_sum, o_dim_row_sum_stride, true, false); // is virtual
// Define the pw multiply descriptor
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply Node
auto multiply_op = binary_pw_op_create(
O_after_dequan, dO_after_dequan,
O_dO_after_pointwise_multiply, multiplyDesc);
// Create multiply node with dropout scale
auto dropout_scale_multiply_op = binary_pw_op_create(
O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor,
O_dO_after_dropout_scale, multiplyDesc);
// Define the reduction descriptor
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction add Node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(O_dO_after_dropout_scale)
.setyDesc(O_dO_after_rowsum)
.setreductionDesc(reductionAddDesc)
.build();
ops->push_back(std::move(multiply_op));
ops->push_back(std::move(dropout_scale_multiply_op));
ops->push_back(std::move(reductionAdd_op));
return O_dO_after_rowsum;
}
static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &dS_after_dropout,
const cudnn_frontend::Tensor &AfterDropout_before_quan_S,
const cudnn_frontend::Tensor &O_dO_after_rowsum,
const cudnn_frontend::Tensor &attnScale) {
int64_t o_dim[4] = {b, h, s_q, s_kv};
int64_t o_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
auto dS_minus_O_dO = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 800,
o_dim, o_stride, true, false); // is virtual
auto AfterAttnScale_before_dS = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801,
o_dim, o_stride, true, false); // is virtual
auto S_mul_dS_minus_O_dO = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802,
o_dim, o_stride, true, false); // is virtual
// Define the pw subtraction descriptor
auto subDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB);
// Create a subtraction Node
auto sub_op = binary_pw_op_create(
dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc);
// Define the pw multiplication descriptor
auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// dS_minus_O_dO * attnScale
auto multiply_attn_scale_op = binary_pw_op_create(
dS_minus_O_dO, attnScale,
AfterAttnScale_before_dS, multiplyDesc);
// AfterDropout_before_quan_S * AfterAttnScale_before_dS
auto multiply_op = binary_pw_op_create(
AfterDropout_before_quan_S, AfterAttnScale_before_dS,
S_mul_dS_minus_O_dO, multiplyDesc);
ops->push_back(std::move(sub_op));
ops->push_back(std::move(multiply_attn_scale_op));
ops->push_back(std::move(multiply_op));
return S_mul_dS_minus_O_dO;
}
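// Reading of the chain above: fused with attn_scale, this is the standard
// softmax gradient
//   dS = attn_scale * S * (dP - rowsum(O * dO))   (elementwise products),
// where S is the (post-dropout) softmax output and rowsum(O * dO) is supplied
// by createdOAndORowReductionChain above.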
static cudnn_frontend::Tensor createdSKBMM(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &dSTensor,
const cudnn_frontend::Tensor &kTensor,
const cudnn_frontend::Tensor &mnkOverride) {
// Creates the necessary tensor descriptors
int64_t after_dSK_dim[4] = {b, h, s_kv, d};
int64_t after_dSK_stride[4] = {h * s_kv * d, d, h * d, 1};
// dS * K
auto After_dS_K = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875,
after_dSK_dim, after_dSK_stride, true, false); // is virtual
// Define the matmul desc
auto matmulDesc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0)
.build();
// Create a matmul Node
auto matmulOp = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dSTensor)
.setbMatDesc(kTensor)
.setcMatDesc(After_dS_K)
.setmOverrideDesc(mnkOverride)
.setkOverrideDesc(mnkOverride)
.setmatmulDesc(matmulDesc)
.build();
ops->push_back(std::move(matmulOp));
return After_dS_K;
}
static cudnn_frontend::Tensor createdSQBMM(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor &dSTensor,
const cudnn_frontend::Tensor &qTensor,
const cudnn_frontend::Tensor &mnkOverride) {
// Creates the necessary tensor descriptors
int64_t dS_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dS_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix);
int64_t dS_transpose_dim[4] = {b, h, s_kv, s_q};
int64_t dS_transpose_stride[4];
dS_transpose_stride[0] = dS_stride[0];
dS_transpose_stride[1] = dS_stride[1];
dS_transpose_stride[2] = dS_stride[3];
dS_transpose_stride[3] = dS_stride[2];
int64_t after_dSTranspose_Q_dim[4] = {b, h, s_kv, d};
int64_t after_dSTranspose_Q_stride[4] = {h * s_kv * d, d, h * d, 1};
auto dSTransposeTensor = tensor_create(
CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650,
dS_transpose_dim, dS_transpose_stride, true, false); // is virtual
// dS.T * Q
auto After_dSTranspose_Q = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651,
after_dSTranspose_Q_dim, after_dSTranspose_Q_stride,
true, false); // is virtual
// Create reshape node for dS -> dS.T
auto reshape_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(dSTensor)
.setyDesc(dSTransposeTensor)
.build();
// Define the matmul desc
auto matmulDesc = cudnn_frontend::MatMulDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setPaddingValue(0)
.build();
// Create a matmul Node
auto matmulOp = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dSTransposeTensor)
.setbMatDesc(qTensor)
.setcMatDesc(After_dSTranspose_Q)
.setmOverrideDesc(mnkOverride)
.setkOverrideDesc(mnkOverride)
.setmatmulDesc(matmulDesc)
.build();
ops->push_back(std::move(reshape_op));
ops->push_back(std::move(matmulOp));
return After_dSTranspose_Q;
}
// fused attention FWD FP8
void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
void* devPtrO,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
void* devPtrAmaxO, void* devPtrAmaxS,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType,
void* workspace_ptr,
size_t* workspace_size,
cudaStream_t stream,
cudnnHandle_t handle_) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
FADescriptor descriptor{
b, h, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability, layout, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static CacheType fa_fprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
// If hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto plan = it->second;
return plan;
}
// Otherwise, build the op_graph and the plan. Then update cache
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
cudnn_frontend::throw_if(dropoutProbability != 0.0f && !isTraining,
"Dropout probability should be 0.0f for inference mode",
CUDNN_STATUS_BAD_PARAM);
cudnn_frontend::throw_if(dropoutProbability == 1.0f,
"Dropout probability cannot be 1.0",
CUDNN_STATUS_BAD_PARAM);
int64_t raggedDim[4] = {b + 1, 1, 1, 1};
int64_t raggedStride[4] = {1, 1, 1, 1};
// Create offset tensors
auto QKVOffsetTensor = tensor_create(
CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"],
raggedDim, raggedStride, false, false);
auto ORaggedOffsetTensor = tensor_create(
CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"],
raggedDim, raggedStride, false, false);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
// Create override tensors
auto seqlenMNKTensor = tensor_create(
CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"],
seqlen_dim, seqlen_stride, false, false);
// Create shared ptrs to ragged offset tensors
// for multiple tensors to use ragged offset
std::shared_ptr<cudnn_frontend::Tensor> QKVRaggedOffsetTensorPtr =
std::make_shared<cudnn_frontend::Tensor>(std::move(QKVOffsetTensor));
std::shared_ptr<cudnn_frontend::Tensor> ORaggedOffsetTensorPtr =
std::make_shared<cudnn_frontend::Tensor>(std::move(ORaggedOffsetTensor));
// Create Q and K tensors that are used in different places
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, s_kv, d};
int64_t k_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
auto qTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["Q"],
q_dim, q_stride, false, false,
QKVRaggedOffsetTensorPtr);
auto kTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["K"],
k_dim, k_stride, false, false,
QKVRaggedOffsetTensorPtr);
// Q * K.T
auto afterQKTensor = createQKBMM(
b, h, s_q, s_kv, d, layout, tensorType,
&ops, qTensor, kTensor,
seqlenMNKTensor, QKVRaggedOffsetTensorPtr);
// QK.T * attn scale
auto AfterAttnScale_before_dequan_Q_tensor = createScale(
afterQKTensor, // input tensor
"AttnScale", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
true, // scale is by value
&ops);
// QK.T * attn scale * dequant_Q
auto AfterAttnScale_before_dequan_K_tensor = createScale(
AfterAttnScale_before_dequan_Q_tensor, // input tensor
"descaleQ", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// QK.T * attn scale * dequant_Q * dequant_K
auto AfterAttnScale_tensor = createScale(
AfterAttnScale_before_dequan_K_tensor, // input tensor
"descaleK", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
auto BeforeDropoutTensor = createSoftmaxForward(
b, h, s_q, s_kv, &ops,
AfterAttnScale_tensor, isTraining);
auto AfterDropout_before_quan_S = createDropoutForward(
b, h, s_q, s_kv, dropoutProbability,
&ops, BeforeDropoutTensor);
// Amax for S
createAmax("amaxS", BeforeDropoutTensor, &ops);
// After softmax * dropout * scale S -> fp8 input to next bmm with V
auto AfterMultiplyDropout = createScale(
AfterDropout_before_quan_S, // input tensor
"scaleS", // scale tensor
tensorType, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// After softmax * Dropout * V
auto OTensor_before_dequan_S_tensor = createSVBMM(
b, h, s_q, s_kv, d, layout, tensorType,
&ops, AfterMultiplyDropout,
seqlenMNKTensor, QKVRaggedOffsetTensorPtr);
// O * dequant_S
auto OTensor_before_dequan_V_tensor = createScale(
OTensor_before_dequan_S_tensor, // input tensor
"descaleS", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// O * dequant_S * dequant_V
auto OTensor_before_quan_O_tensor = createScale(
OTensor_before_dequan_V_tensor, // input tensor
"descaleV", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// O * dequant_S * dequant_V * scale O
auto OTensor = createScaleWithOffset(
OTensor_before_quan_O_tensor, // input tensor
"scaleO", // scale tensor
tensorType, // output tensor type
false, // output not virtual
false, // scale is by value
&ops,
ORaggedOffsetTensorPtr, // ragged offset
"O");
// Amax for O
createAmax("amaxO", OTensor_before_quan_O_tensor, &ops);
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph,
allowAllConfig, filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr,
CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_fprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle_)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
}; // end of get_plan
auto plan = get_plan(fa_fprop_cache, descriptor);
size_t wkspace_size = static_cast<size_t>(plan.getWorkspaceSize());
// Exit to request upper level API to allocate memory if needed
if (workspace_ptr == nullptr) {
*workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t);
return;
}
int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr)
+ wkspace_size + (b + 1) * sizeof(int32_t));
int32_t* actual_seqlens_q = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr)
+ wkspace_size + (b + 1) * 2 * sizeof(int32_t));
// FP8 currently only supports self-attention, so devPtrcuSeqlensKV is not used
dim3 blockDims(128);
dim3 gridDims((b + blockDims.x)/blockDims.x);
cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ),
actual_seqlens_q, qkv_ragged_offset, o_ragged_offset);
void* devPtrQKVRaggedOffset = reinterpret_cast<void *>(qkv_ragged_offset);
void* devPtrORaggedOffset = reinterpret_cast<void *>(o_ragged_offset);
void* devPtrMNKOverride = reinterpret_cast<void *>(actual_seqlens_q);
float dropoutScale = 1.0f/(1.0f - dropoutProbability);
std::set<std::pair<uint64_t, void*>> data_ptrs;
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["Q"], devPtrQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["K"], devPtrK));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["K_TRANSPOSE"], devPtrK));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["V"], devPtrV));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["AttnScale"], &attnScale));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["O"], devPtrO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleQ"], devPtrDescaleQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleK"], devPtrDescaleK));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleV"], devPtrDescaleV));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleS"], devPtrDescaleS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaleS"], devPtrScaleS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaleO"], devPtrScaleO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["amaxO"], devPtrAmaxO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["amaxS"], devPtrAmaxS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride));
// If training, then we need to write out M and Z_INV
if (isTraining) {
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["M"], devPtrM));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["Z_INV"], devPtrZInv));
}
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(data_ptrs)
.build();
cudnnStatus_t status = cudnnBackendExecute(
handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
cudnn_frontend::throw_if(
[status]() { return (status != CUDNN_STATUS_SUCCESS); },
"Plan execute error", status);
} catch (cudnn_frontend::cudnnException& e) {
struct cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0));
// FP8 fused attention requires Hopper (GH100) GPUs and cuDNN >= 8900
if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900))
&& (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH
|| e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) {
std::cout << "FP8 fused attention is only supported on GH100 (cuDNN >= 8900) GPUs" << std::endl;
} else {
std::cout << "[ERROR] Exception " << e.what() << std::endl;
}
}
}
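// fa_fwd_fp8 (and fa_bwd_fp8 below) follow a two-call convention: invoked
// with a null workspace_ptr they only write the required byte count (plan
// workspace plus scratch for the ragged-offset/seqlen arrays) to
// *workspace_size and return; the caller allocates and calls again. A toy
// sketch of the same convention (hypothetical function, not the real
// fa_fwd_fp8 signature):
static void toy_op(void* workspace_ptr, size_t* workspace_size) {
  const size_t needed = 256;       // whatever the plan would require
  if (workspace_ptr == nullptr) {  // first call: size query only
    *workspace_size = needed;
    return;
  }
  // second call: workspace_ptr points at >= needed bytes; run the plan here
}
static void toy_caller() {
  size_t size = 0;
  toy_op(nullptr, &size);          // 1. query the size
  void* ws = nullptr;
  cudaMalloc(&ws, size);           // 2. caller allocates device memory
  toy_op(ws, &size);               // 3. execute with the workspace
  cudaFree(ws);
}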
// fused attention BWD FP8
void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
void* devPtrO, void* devPtrdO,
void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV,
void* devPtrDescaleO, void* devPtrDescaledO,
void* devPtrDescaleS, void* devPtrDescaledS,
void* devPtrScaleS, void* devPtrScaledS,
void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdS,
void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
void* devPtrDropoutSeed, void* devPtrDropoutOffset,
cudnnDataType_t tensorType,
void* workspace_ptr,
size_t* workspace_size,
cudaStream_t stream,
cudnnHandle_t handle_) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));
FADescriptor descriptor{
b, h, s_q, s_kv, d,
attnScale, false, dropoutProbability, layout, tensorType};
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static CacheType fa_bprop_cache;
// Get plan from cache if cache is available, otherwise create one
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
// If hit, return
auto it = cache.find(descriptor);
if (it != cache.end()) {
auto plan = it->second;
return plan;
}
// Otherwise, build the op_graph and the plan. Then update cache
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;
cudnn_frontend::throw_if(dropoutProbability == 1.0f,
"Dropout probability cannot be 1.0",
CUDNN_STATUS_BAD_PARAM);
int64_t raggedDim[4] = {b + 1, 1, 1, 1};
int64_t raggedStride[4] = {1, 1, 1, 1};
// Create offset tensors
auto QKVOffsetTensor = tensor_create(
CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"],
raggedDim, raggedStride, false, false);
auto ORaggedOffsetTensor = tensor_create(
CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"],
raggedDim, raggedStride, false, false);
// Create shared ptrs to ragged offset tensors for multiple tensors
std::shared_ptr<cudnn_frontend::Tensor> QKVRaggedOffsetTensorPtr =
std::make_shared<cudnn_frontend::Tensor>(std::move(QKVOffsetTensor));
std::shared_ptr<cudnn_frontend::Tensor> ORaggedOffsetTensorPtr =
std::make_shared<cudnn_frontend::Tensor>(std::move(ORaggedOffsetTensor));
// Create Q and K tensors that are used in different places
int64_t q_dim[4] = {b, h, s_q, d};
int64_t q_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t k_dim[4] = {b, h, s_kv, d};
int64_t k_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
auto qTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["Q"],
q_dim, q_stride, false, false, QKVRaggedOffsetTensorPtr);
auto kTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["K"],
k_dim, k_stride, false, false, QKVRaggedOffsetTensorPtr);
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
// Create attnScale tensor for multiple ops to use
auto attnScaleTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"],
scale_dim, scale_stride, false, true); // is by value
// Create descale Q K dO dS global tensors since they are used in multiple places
auto descaleQTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"],
scale_dim, scale_stride, false, false);
auto descaleKTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"],
scale_dim, scale_stride, false, false);
auto descaledOTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"],
scale_dim, scale_stride, false, false);
auto descaledSTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"],
scale_dim, scale_stride, false, false);
int64_t seqlen_dim[4] = {b, 1, 1, 1};
int64_t seqlen_stride[4] = {1, 1, 1, 1};
// Create MNK override tensor
auto seqlenMNKTensor = tensor_create(
CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"],
seqlen_dim, seqlen_stride, false, false);
int64_t O_dim[4] = {b, h, s_q, d};
int64_t O_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
// Create O and loss tensor
auto OTensor = tensor_create_with_offset(
tensorType, tensor_name_to_uid["O"],
O_dim, O_stride, false, false, ORaggedOffsetTensorPtr);
// dO is used in multiple places and E5M2
auto dOTensor = tensor_create_with_offset(
CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"],
O_dim, O_stride, false, false, ORaggedOffsetTensorPtr);
// Q * K.T
auto afterQKTensor = createQKBMM(
b, h, s_q, s_kv, d, layout, tensorType,
&ops, qTensor, kTensor,
seqlenMNKTensor, QKVRaggedOffsetTensorPtr);
// QK.T * attn scale
auto AfterAttnScale_before_dequan_Q_tensor = createScale(
afterQKTensor, // input tensor
attnScaleTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
true, // scale is by value
&ops,
1999 /*UID offset*/);
// QK.T * attn scale * dequant_Q
auto AfterAttnScale_before_dequan_K_tensor = createScale(
AfterAttnScale_before_dequan_Q_tensor, // input tensor
descaleQTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2000 /*UID offset*/);
// QK.T * attn scale * dequant_Q * dequant_K
auto AfterAttnScale_tensor = createScale(
AfterAttnScale_before_dequan_K_tensor, // input tensor
descaleKTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2001 /*UID offset*/);
auto beforeDropout_QKt_Tensor = createSoftmaxBackward(
b, h, s_q, s_kv, &ops, AfterAttnScale_tensor);
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
// mask for the dropout. Used in different places
auto dropoutMaskTensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
auto AfterDropout_before_quan_S = createDropoutBackward(
b, h, s_q, s_kv, dropoutProbability,
&ops, beforeDropout_QKt_Tensor, dropoutMaskTensor);
// After softmax * scale S -> fp8 input to next bmm with dO
auto AfterMultiply = createScale(
AfterDropout_before_quan_S, // input tensor
"scaleS", // scale tensor
tensorType, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// After softmax * dO
auto dVTensor_before_dequan_S = createSdOBMM(
b, h, s_q, s_kv, d, tensorType,
&ops, AfterMultiply, dOTensor, seqlenMNKTensor);
// (S * dO) * dequant_S
auto dVTensor_before_dequan_dO = createScale(
dVTensor_before_dequan_S, // input tensor
"descaleS", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// (S * dO) * dequant_S * dequant_dO
auto dVTensor_before_quan_dV = createScale(
dVTensor_before_dequan_dO, // input tensor
descaledOTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2002 /*UID offset*/);
// O * dequant_S * dequant_dO * scale dV
auto dVTensor = createScaleWithOffset(
dVTensor_before_quan_dV, // input tensor
"scaledV", // scale tensor
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
&ops,
QKVRaggedOffsetTensorPtr, // ragged offset
"dV" /*Output tensor name*/);
// Amax for dV
createAmax("amaxdV", dVTensor_before_quan_dV, &ops);
auto dS_before_dequan_dO_Tensor = createdOVBMM(
b, h, s_q, s_kv, d, layout, tensorType,
&ops, dOTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr);
// dS * dequant_dO
auto dS_before_dequan_V = createScale(
dS_before_dequan_dO_Tensor, // input tensor
descaledOTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2003 /*UID offset*/);
// dS * dequant_dO * dequant_V
auto dS_after_dequan = createScale(
dS_before_dequan_V, // input tensor
"descaleV", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// Multiply by the dropout (RNG) mask
auto beforeDropoutScale_dOVt_Tensor = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// After dropout mask and scale
auto dS_after_dropout = tensor_create(
CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351,
afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
// Define the multiply mask descriptor
auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
// Create a multiply mask Node
auto maskMul_op = binary_pw_op_create(
dS_after_dequan, dropoutMaskTensor,
beforeDropoutScale_dOVt_Tensor, mulDesc);
ops.push_back(std::move(maskMul_op));
// scale after dropout for dO and O chain
auto dropoutScale_dOVt_OdO_Tensor = tensor_create(
tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"],
scale_dim, scale_stride, false, true); // is by value
// Create a multiply dropout scale Node
auto mul_dropout_scale_op = binary_pw_op_create(
beforeDropoutScale_dOVt_Tensor,
dropoutScale_dOVt_OdO_Tensor,
dS_after_dropout, mulDesc);
ops.push_back(std::move(mul_dropout_scale_op));
// O * dequant_O
auto O_after_dequan_Tensor = createScale(OTensor, // input tensor
"descaleO", // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// dO * dequant_dO
auto dO_after_dequan_Tensor = createScale(dOTensor, // input tensor
descaledOTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2004 /*UID offset*/);
// row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)]
auto O_dO_after_rowsum = createdOAndORowReductionChain(
b, h, s_q, s_kv, d, layout,
&ops, O_after_dequan_Tensor,
dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor);
// (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale
auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain(
b, h, s_q, s_kv, d, layout,
&ops, dS_after_dropout,
AfterDropout_before_quan_S, O_dO_after_rowsum,
attnScaleTensor);
// S_mul_dS_minus_O_dO * scaledS
auto S_mul_dS_minus_O_dO_after_quan_dS = createScale(
S_mul_dS_minus_O_dO, // input tensor
"scaledS", // scale tensor
CUDNN_DATA_FP8_E5M2, // output tensor type
true, // output is virtual
false, // scale is by value
&ops);
// Amax for dS
createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops);
// dS @ K
auto After_dS_K = createdSKBMM(
b, h, s_q, s_kv, d, &ops,
S_mul_dS_minus_O_dO_after_quan_dS,
kTensor, seqlenMNKTensor);
// (dS * K) * descale dS
auto After_dS_K_before_dequan_K = createScale(
After_dS_K, // input tensor
descaledSTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2006 /*UID offset*/);
// (dS * K) * descale dS * descale K
auto After_dS_K_before_quan_dQ = createScale(
After_dS_K_before_dequan_K, // input tensor
descaleKTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2007 /*UID offset*/);
// (dS * K) * descale dS * descale K * scale dQ
auto dQ = createScaleWithOffset(
After_dS_K_before_quan_dQ, // input tensor
"scaledQ", // scale tensor
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
&ops,
QKVRaggedOffsetTensorPtr, // ragged offset
"dQ");
// Amax for dQ
createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops);
// dS.T @ Q
auto After_dSTranspose_Q = createdSQBMM(
b, h, s_q, s_kv, d, layout, &ops,
S_mul_dS_minus_O_dO_after_quan_dS,
qTensor, seqlenMNKTensor);
// (dS.T * Q) * descale dS
auto After_dSTranspose_Q_before_dequan_Q = createScale(
After_dSTranspose_Q, // input tensor
descaledSTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2009 /*UID offset*/);
// (dS.T * Q) * descale dS * descale Q
auto After_dSTranspose_Q_before_quan_dK = createScale(
After_dSTranspose_Q_before_dequan_Q, // input tensor
descaleQTensor, // scale tensor
CUDNN_DATA_FLOAT, // output tensor type
true, // output is virtual
false, // scale is by value
&ops,
2010 /*UID offset*/);
// (dS.T * Q) * descale dS * descale Q * scale dK
auto dK = createScaleWithOffset(
After_dSTranspose_Q_before_quan_dK, // input tensor
"scaledK", // scale tensor
CUDNN_DATA_FP8_E5M2, // output tensor type
false, // output not virtual
false, // scale is by value
&ops,
QKVRaggedOffsetTensorPtr, // ragged offset
"dK");
// Amax for dK
createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops);
for (unsigned int i = 0; i < ops.size(); i++) {
all_ops.push_back(&ops[i]);
}
// Create an Operation Graph
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(all_ops.size(), all_ops.data())
.build();
cudnn_frontend::EngineConfigList filtered_configs;
auto statuses = cudnn_frontend::get_heuristics_list<1>(
{"heuristics_instant"}, opGraph,
allowAllConfig, filtered_configs, true);
if (filtered_configs.size() == 0) {
cudnn_frontend::set_error_and_throw_exception(
nullptr,
CUDNN_STATUS_NOT_SUPPORTED,
"run_mha_bprop: No config returned by the heuristics");
}
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle_)
.setEngineConfig(filtered_configs[0], opGraph.getTag())
.build();
cache.insert({descriptor, plan});
return plan;
};
auto plan = get_plan(fa_bprop_cache, descriptor);
size_t wkspace_size = static_cast<size_t>(plan.getWorkspaceSize());
// Exit to request upper level API to allocate memory if needed
if (workspace_ptr == nullptr) {
*workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t);
return;
}
int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr)
+ wkspace_size + (b + 1) * sizeof(int32_t));
int32_t* actual_seqlens_q = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr)
+ wkspace_size + (b + 1) * 2 * sizeof(int32_t));
// FP8 currently only supports self-attention, so devPtrcuSeqlensKV is not used
dim3 blockDims(128);
dim3 gridDims((b + blockDims.x)/blockDims.x);
cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ),
actual_seqlens_q, qkv_ragged_offset, o_ragged_offset);
void* devPtrQKVRaggedOffset = reinterpret_cast<void *>(qkv_ragged_offset);
void* devPtrORaggedOffset = reinterpret_cast<void *>(o_ragged_offset);
void* devPtrMNKOverride = reinterpret_cast<void *>(actual_seqlens_q);
std::set<std::pair<uint64_t, void*>> data_ptrs;
float dropoutScale = 1.0f/(1.0f - dropoutProbability);
float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability;
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["Q"], devPtrQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["K"], devPtrK));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["K_TRANSPOSE"], devPtrK));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["V"], devPtrV));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["V_TRANSPOSE"], devPtrV));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["dQ"], devPtrdQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["dK"], devPtrdK));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["dV"], devPtrdV));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["dO"], devPtrdO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["AttnScale"], &attnScale));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"],
&dropoutScale_dOVt_OdO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["M"], devPtrM));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["Z_INV"], devPtrZInv));
data_ptrs.emplace(std::pair<uint64_t, void*>(tensor_name_to_uid["O"], devPtrO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleQ"], devPtrDescaleQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleK"], devPtrDescaleK));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleV"], devPtrDescaleV));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleS"], devPtrDescaleS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaledS"], devPtrDescaledS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaleO"], devPtrDescaleO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["descaledO"], devPtrDescaledO));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaleS"], devPtrScaleS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaledS"], devPtrScaledS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaledQ"], devPtrScaledQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaledK"], devPtrScaledK));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["scaledV"], devPtrScaledV));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["amaxdS"], devPtrAmaxdS));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["amaxdK"], devPtrAmaxdK));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["amaxdV"], devPtrAmaxdV));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset));
data_ptrs.emplace(std::pair<uint64_t, void*>(
tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride));
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(data_ptrs)
.build();
cudnnStatus_t status = cudnnBackendExecute(
handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
cudnn_frontend::throw_if(
[status]() { return (status != CUDNN_STATUS_SUCCESS); },
"Plan execute error", status);
} catch (cudnn_frontend::cudnnException& e) {
struct cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0));
// FP8 fused attention requires Hopper (GH100) GPUs and cuDNN >= 8900
if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900))
&& (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH
|| e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) {
std::cout << "FP8 fused attention is only supported on GH100 (cuDNN >= 8900) GPUs" << std::endl;
} else {
std::cout << "[ERROR] Exception " << e.what() << std::endl;
}
}
}
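// All the scale/descale/amax pointers threaded through fa_fwd_fp8 and
// fa_bwd_fp8 follow the usual FP8 delayed-scaling convention: tensors are
// stored as x_q = clamp(x * scale), recovered as x_q * descale with
// descale = 1 / scale, and every producing kernel records amax = max|x| in
// unscaled units so the framework can pick the next step's scale. A small
// standalone sketch (illustrative values, not library code):
static void fp8_scaling_demo() {
  const float e4m3_max = 448.0f;       // largest finite FP8 E4M3 magnitude
  float amax_prev = 10.0f;             // amax observed in an earlier step
  float scale = e4m3_max / amax_prev;  // map observed range onto FP8 range
  float descale = 1.0f / scale;
  float x = 7.5f;
  float x_scaled = x * scale;
  if (x_scaled > e4m3_max) x_scaled = e4m3_max;    // saturate to FP8 range
  if (x_scaled < -e4m3_max) x_scaled = -e4m3_max;
  float x_back = x_scaled * descale;   // ~x, up to FP8 rounding
  float amax = x < 0 ? -x : x;         // tracked in unscaled units
  (void)x_back; (void)amax;
}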
#endif
} // namespace fused_attn
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [total_seqs, 3, h, d]
void* devPtrQKV = input_QKV->data.dptr;
void* devPtrQ = reinterpret_cast<void *>(devPtrQKV);
void* devPtrK = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + h * d);
void* devPtrV = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + 2 * h * d);
void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
void* devPtrDescaleK = input_QKV->scale_inv.dptr;
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
void* devPtrO = output_O->data.dptr;
void* devPtrAmaxO = output_O->amax.dptr;
void* devPtrScaleO = output_O->scale.dptr;
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_Output_Tensors->size == 0) {
if (is_training) {
Aux_Output_Tensors->size = 2;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32;
}
} else if (Aux_Output_Tensors->size == 2) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
}
void* devPtrAmaxS = input_output_S->amax.dptr;
void* devPtrScaleS = input_output_S->scale.dptr;
void* devPtrDescaleS = input_output_S->scale_inv.dptr;
void* devPtrcuSeqlens = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn::fa_fwd_fp8(
b, max_seqlen, max_seqlen, h, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleS, devPtrScaleS, devPtrScaleO,
devPtrAmaxO, devPtrAmaxS,
devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
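// Note on the Q/K/V base pointers above: QKV is packed as
// [total_seqs, 3, h, d], so K and V start one and two [h, d] slabs after Q.
// Adding h * d to an int8_t* is an element offset only because FP8 elements
// are one byte wide; a sketch of the general form (hypothetical helper, not
// part of this file):
static void qkv_base_ptrs(void* qkv, int64_t h, int64_t d, size_t elem_bytes,
                          void** q, void** k, void** v) {
  int8_t* base = reinterpret_cast<int8_t*>(qkv);
  *q = base;                           // Q rows start at the base
  *k = base + 1 * h * d * elem_bytes;  // K starts one [h, d] slab later
  *v = base + 2 * h * d * elem_bytes;  // V starts two slabs later
  // the per-token stride of 3 * h * d is supplied by generateMatrixStrides
}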
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQKV,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle) {
using namespace transformer_engine;
// QKV shape is [total_seqs, 3, h, d]
void* devPtrQKV = input_QKV->data.dptr;
void* devPtrQ = reinterpret_cast<void *>(devPtrQKV);
void* devPtrK = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + h * d);
void* devPtrV = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrQKV) + 2 * h * d);
void* devPtrDescaleQ = input_QKV->scale_inv.dptr;
void* devPtrDescaleK = input_QKV->scale_inv.dptr;
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
void* devPtrDescaleO = input_O->scale_inv.dptr;
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
void* devPtrM = input_M->data.dptr;
void* devPtrZInv = input_ZInv->data.dptr;
void* devPtrScaleS = input_S->scale.dptr;
void* devPtrDescaleS = input_S->scale_inv.dptr;
void* devPtrAmaxdS = input_output_dP->amax.dptr;
void* devPtrScaledS = input_output_dP->scale.dptr;
void* devPtrDescaledS = input_output_dP->scale_inv.dptr;
// dQKV shape is [total_seqs, 3, h, d]
void* devPtrdQKV = output_dQKV->data.dptr;
void* devPtrdQ = reinterpret_cast<void *>(devPtrdQKV);
void* devPtrdK = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrdQKV) + h * d);
void* devPtrdV = reinterpret_cast<void *>(reinterpret_cast<int8_t*>(devPtrdQKV) + 2 * h * d);
void* devPtrAmaxdQ = output_dQKV->amax.dptr;
void* devPtrAmaxdK = output_dQKV->amax.dptr;
void* devPtrAmaxdV = output_dQKV->amax.dptr;
void* devPtrScaledQ = output_dQKV->scale.dptr;
void* devPtrScaledK = output_dQKV->scale.dptr;
void* devPtrScaledV = output_dQKV->scale.dptr;
void* devPtrcuSeqlens = reinterpret_cast<void *>(
reinterpret_cast<int32_t*>(cu_seqlens->data.dptr));
void* devPtrDropoutSeed = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr));
void* devPtrDropoutOffset = reinterpret_cast<void *>(
reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn::fa_bwd_fp8(
b, max_seqlen, max_seqlen, h, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
devPtrM, devPtrZInv,
devPtrO, devPtrdO,
devPtrdQ, devPtrdK, devPtrdV,
devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
devPtrDescaleO, devPtrDescaledO,
devPtrDescaleS, devPtrDescaledS,
devPtrScaleS, devPtrScaledS,
devPtrScaledQ, devPtrScaledK, devPtrScaledV,
devPtrAmaxdS,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV,
devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset,
get_cudnn_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = { workspace_size };
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = { 1 };
workspace->data.dtype = DType::kByte;
return;
}
}
#endif // end of CUDNN>=8900
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQKV,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
#endif // end of CUDNN>=8900
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
namespace transformer_engine {
namespace fused_attn {
using namespace transformer_engine;
// get matrix strides based on matrix type
void generateMatrixStrides(
int64_t b, int64_t h,
int64_t s_q, int64_t s_kv,
int64_t d, int64_t* strideA,
NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) {
constexpr int batch_dim_idx = 0;
constexpr int head_dim_idx = 1;
constexpr int seqlen_dim_idx = 2;
constexpr int hidden_dim_idx = 3;
constexpr int seqlen_transpose_dim_idx = 3;
constexpr int hidden_transpose_dim_idx = 2;
constexpr int seqlen_q_dim_idx = 2;
constexpr int seqlen_kv_dim_idx = 3;
switch (matrix) {
case NVTE_QKV_Matrix::NVTE_Q_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * 3 * h * d;
} else {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_S_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = s_kv;
strideA[head_dim_idx] = s_q * s_kv;
strideA[batch_dim_idx] = h * s_q * s_kv;
break;
case NVTE_QKV_Matrix::NVTE_O_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
break;
}
}
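// Worked example (a check on the interleaved case, not library code): for
// NVTE_QKV_INTERLEAVED with b = 2, h = 16, s_q = s_kv = 512, d = 64, the Q
// matrix viewed as [b, h, s_q, d] gets
//   strideA = {s_q * 3 * h * d, d, 3 * h * d, 1} = {1572864, 64, 3072, 1};
// consecutive sequence positions sit 3 * h * d elements apart because each
// token's K and V rows are interleaved between its Q rows.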
// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
int32_t *qkv_ragged_offset, int32_t *o_ragged_offset) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b) {
actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid];
}
if (tid < b + 1) {
qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d;
o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d;
}
}
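// Host-side reference of what the kernel computes (a worked example with
// assumed values, not library code): for b = 2, h = 16, d = 64 and
// cu_seqlens_q = {0, 5, 12}:
//   actual_seqlens_q  = {5, 7}              (per-sequence lengths)
//   qkv_ragged_offset = {0, 15360, 36864}   (cu_seqlens * 3 * h * d)
//   o_ragged_offset   = {0, 5120, 12288}    (cu_seqlens * h * d)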
} // namespace fused_attn
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#include "transformer_engine/transformer_engine.h"
#include <cudnn_frontend.h>
namespace transformer_engine {
namespace fused_attn {
using namespace transformer_engine;
enum NVTE_QKV_Matrix {
NVTE_Q_Matrix = 0, // queries
NVTE_K_Matrix = 1, // keys
NVTE_K_Matrix_Transpose = 2, // keys transposed
NVTE_V_Matrix = 3, // values
NVTE_V_Matrix_Transpose = 4, // value matrix transposed
NVTE_S_Matrix = 5, // output of GEMM1
NVTE_O_Matrix = 6, // final output
};
void generateMatrixStrides(
int64_t b, int64_t h,
int64_t s_q, int64_t s_kv,
int64_t d, int64_t* strideA,
NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);
struct FADescriptor {
std::int64_t b;
std::int64_t h;
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
float attnScale;
bool isTraining;
float dropoutProbability;
NVTE_QKV_Layout layout;
cudnnDataType_t tensor_type;
bool operator<(const FADescriptor &rhs) const {
return std::tie(b, h, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability,
layout, tensor_type) < std::tie(
rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.tensor_type);
}
};
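// FADescriptor's operator< provides the strict weak (lexicographic)
// ordering std::map requires; fa_fwd_fp8/fa_bwd_fp8 rely on it to key their
// static plan caches on the full problem shape. A minimal sketch of the
// same pattern (illustrative only; assumes <map> and <tuple> are
// available):
struct ToyPlanKey {
  int b, h;
  bool operator<(const ToyPlanKey& rhs) const {
    return std::tie(b, h) < std::tie(rhs.b, rhs.h);
  }
};
inline bool toy_plan_cache_demo() {
  std::map<ToyPlanKey, int> cache;            // key -> "plan"
  cache.insert({{2, 16}, 1});                 // built once, on first use
  return cache.find({2, 16}) != cache.end();  // later lookups hit the cache
}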
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] {
if (handle_ != nullptr) {
cudnnDestroy(handle_);
}});
}
private:
cudnnHandle_t handle_ = nullptr;
};
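// Usage sketch (illustrative): the manager hands out one lazily created,
// thread-local cuDNN handle, so callers never call cudnnCreate directly:
//   cudnnHandle_t handle =
//       cudnnExecutionPlanManager::Instance().GetCudnnHandle();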
} // namespace transformer_engine
#endif
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
enum NVTE_QKV_Layout {
/*!< separate Q, K, V tensors:
Q: [total_seqs_q, num_heads, head_dim]
| Q Q Q ... Q
| \___________ _____________/
total_seqs_q <| \/
| num_heads * head_dim
K: [total_seqs_kv, num_heads, head_dim]
| K K K ... K
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
V: [total_seqs_kv, num_heads, head_dim]
| V V V ... V
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
*/
NVTE_NOT_INTERLEAVED = 0,
/*!< packed QKV tensor:
QKV: [total_seqs, 3, num_heads, head_dim]
| Q Q Q ... Q K K K ... K V V V ... V
| \___________ _____________/
total_seqs <| \/
| num_heads * head_dim
*/
NVTE_QKV_INTERLEAVED = 1,
/*!< Q and packed KV tensor:
Q: [total_seqs_q, num_heads, head_dim]
| Q Q Q ... Q
| \___________ _____________/
total_seqs_q <| \/
| num_heads * head_dim
KV: [total_seqs_kv, 2, num_heads, head_dim]
| K K K ... K V V V ... V
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
*/
NVTE_KV_INTERLEAVED = 2
};
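/* Worked sketch of the per-token row addresses each layout implies (for
   FP8, element size 1 byte; an illustration, not part of the API):
     NVTE_NOT_INTERLEAVED: Q[t] = q_base   + t * h * d
     NVTE_QKV_INTERLEAVED: Q[t] = qkv_base + t * 3 * h * d,
                           K[t] = Q[t] + h * d,  V[t] = K[t] + h * d
     NVTE_KV_INTERLEAVED:  K[t] = kv_base  + t * 2 * h * d,
                           V[t] = K[t] + h * d
*/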
enum NVTE_Bias_Type {
NVTE_NO_BIAS = 0, /*!< no bias */
NVTE_PRE_SCALE_BIAS = 1, /*!< bias before scale */
NVTE_POST_SCALE_BIAS = 2 /*!< bias after scale */
};
enum NVTE_Mask_Type {
NVTE_PADDING_MASK = 0, /*!< padding attention mask */
NVTE_CAUSAL_MASK = 1, /*!< causal attention mask */
NVTE_NO_MASK = 2 /*!< no masking */
};
/*! \brief Compute dot product attention with packed QKV input.
*
* Computes:
* - P = Q * K.T + Bias
* - S = ScaleMaskSoftmax(P)
* - D = Dropout(S)
* - O = D * V
*
* Support Matrix:
* | precision | qkv layout | bias | mask | sequence length | head_dim |
* | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 |
*
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
* Support Matrix:
* | precision | qkv layout | bias | mask | sequence length | head_dim |
* | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 |
*
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] dBias The gradient of the Bias tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQKV,
const NVTETensor cu_seqlens,
size_t max_seqlen,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
* Computes:
* - P = Q * K.T + Bias
* - S = ScaleMaskSoftmax(P)
* - D = Dropout(S)
* - O = D * V
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q,
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing for KV,
* it may be >= max(cu_seqlens_kv).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] dBias The gradient of the Bias tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q,
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing for KV,
* it may be >= max(cu_seqlens_kv).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dKV,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif
...@@ -9,6 +9,7 @@
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include <string>
#include <stdexcept>
...@@ -39,10 +40,18 @@ inline void check_cublas_(cublasStatus_t status) {
  }
}
inline void check_cudnn_(cudnnStatus_t status) {
  if ( status != CUDNN_STATUS_SUCCESS ) {
    NVTE_ERROR("CUDNN Error: " + std::string(cudnnGetErrorString(status)));
  }
}
}  // namespace
#define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); }
#define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); }
#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); }
#endif  // TRANSFORMER_ENGINE_LOGGING_H_