"vscode:/vscode.git/clone" did not exist on "a0f443548463b83c0678a56aebbffbcb13b0a5e0"
Unverified commit 989a53a0, authored by cyanguwa and committed by GitHub

Add FP8 fused attention (#155)



* Add FP8 fused attention to TE for PyTorch
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add c api docs for fused attention
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add exception for unsupported precision/sequence length combinations
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix installation requirement for non fused attn use cases
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix docs for fused-attn
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes based on PR comments
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description for kvpacked fwd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix description of Bias in C api
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* minor fixes for cudnn requirement and description for QKV tensors
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix QKV layout description and support matrix for C api
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* add asserts to cpp_extensions for qkv layout/bias type/attn mask type
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix typo precision
Signed-off-by: Charlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c3407300
......@@ -17,6 +17,8 @@ jobs:
steps:
- name: 'Checkout'
uses: actions/checkout@v3
with:
submodules: recursive
- name: 'Build'
run: |
mkdir -p wheelhouse && \
......@@ -41,6 +43,8 @@ jobs:
steps:
- name: 'Checkout'
uses: actions/checkout@v3
with:
submodules: recursive
- name: 'Build'
run: |
pip install ninja pybind11 && \
......@@ -66,6 +70,8 @@ jobs:
steps:
- name: 'Checkout'
uses: actions/checkout@v3
with:
submodules: recursive
- name: 'Build'
run: |
pip install ninja pybind11 && \
......
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609
......@@ -138,3 +138,25 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================
cudnn-frontend
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
..
Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
fused_attn.h
============
.. doxygenfile:: fused_attn.h
......@@ -17,6 +17,7 @@ directly from C/C++, without Python.
activation.h <activation>
cast.h <cast>
gemm.h <gemm>
fused_attn.h <fused_attn>
layer_norm.h <layer_norm>
softmax.h <softmax>
transformer_engine.h <transformer_engine>
......
......@@ -14,6 +14,8 @@ Prerequisites
1. Linux x86_64
2. `CUDA 11.8 <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 11.8 or later.
4. `cuDNN 8 <https://developer.nvidia.com/cudnn>`__ or later.
5. For FP8 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 <https://developer.nvidia.com/cudnn>`__ or later.
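To confirm the cuDNN and CUDA versions listed above, a minimal sketch (illustrative only, assuming the cuDNN and CUDA runtime headers are on the include path)::

    #include <cudnn.h>
    #include <cuda_runtime_api.h>
    #include <cstdio>

    int main() {
        // CUDNN_VERSION is the compile-time version, e.g. 8900 for cuDNN 8.9.0;
        // cudnnGetVersion() reports the version of the loaded library.
        std::printf("cuDNN: build %d, runtime %zu\n", CUDNN_VERSION, (size_t)cudnnGetVersion());
        int rt = 0;
        cudaRuntimeGetVersion(&rt);  // e.g. 12010 for CUDA 12.1
        std::printf("CUDA runtime: %d\n", rt);
        return 0;
    }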
Transformer Engine in NGC Containers
......
......@@ -105,6 +105,7 @@ framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")
include_dirs = [
"transformer_engine/common/include",
"transformer_engine/pytorch/csrc",
"3rdparty/cudnn-frontend/include",
]
if NVTE_WITH_USERBUFFERS:
if MPI_HOME:
......
......@@ -42,6 +42,7 @@ const std::string &typeName(DType type) {
static const std::unordered_map<DType, std::string> name_map = {
{DType::kByte, "byte"},
{DType::kInt32, "int32"},
{DType::kInt64, "int64"},
{DType::kFloat32, "float32"},
{DType::kFloat16, "float16"},
{DType::kBFloat16, "bfloat16"},
......
......@@ -44,6 +44,7 @@ struct BytesToType<8> {
using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
......@@ -54,6 +55,7 @@ template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int32,
int64,
fp32,
fp16,
bf16,
......@@ -211,6 +213,12 @@ bool isFp8Type(DType type);
{__VA_ARGS__} \
} \
break; \
case DType::kInt64: \
{ \
using type = int64; \
{__VA_ARGS__} \
} \
break; \
case DType::kFloat32: \
{ \
using type = float; \
......
......@@ -19,7 +19,9 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/")
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(CUDNN REQUIRED cudnn)
find_package(Python COMPONENTS Interpreter Development REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
......
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
find_path(
CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include
)
function(find_cudnn_library NAME)
string(TOUPPER ${NAME} UPPERCASE_NAME)
find_library(
${UPPERCASE_NAME}_LIBRARY ${NAME}
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib
)
if(${UPPERCASE_NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties(
CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${UPPERCASE_NAME}_LIBRARY}
)
message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.")
else()
message(STATUS "${NAME} not found.")
endif()
endfunction()
find_cudnn_library(cudnn)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)
include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
CUDNN REQUIRED_VARS
CUDNN_INCLUDE_DIR CUDNN_LIBRARY
)
if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
message(STATUS "cuDNN: ${CUDNN_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
else()
set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")
endif()
target_include_directories(
CUDNN::cudnn_all
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
CUDNN::cudnn
)
......@@ -12,6 +12,9 @@ list(APPEND transformer_engine_SOURCES
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
......@@ -30,9 +33,11 @@ target_include_directories(transformer_engine PUBLIC
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDA::nvToolsExt)
CUDA::nvToolsExt
CUDNN::cudnn)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include")
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp8.h"
// NVTE fused attention FWD FP8 with packed QKV
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_fwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[2];
size_t d = input_QKV->data.shape[3];
const DType QKV_type = input_QKV->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8900)
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
// FP8 API doesn't use input_Bias, bias_type or attn_mask_type
fused_attn_fwd_fp8_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O,
Aux_Output_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
// NVTE fused attention BWD FP8 with packed QKV
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQKV,
const NVTETensor cu_seqlens,
size_t max_seqlen,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_bwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQKV = reinterpret_cast<Tensor*>(dQKV);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// QKV shape is [total_seqs, 3, h, d]
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = input_QKV->data.shape[2];
size_t d = input_QKV->data.shape[3];
const DType QKV_type = input_QKV->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8900)
// Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
// FP8 API doesn't use input_dBias, bias_type or attn_mask_type
fused_attn_bwd_fp8_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
output_dQKV,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
// NVTE fused attention FWD FP8 with packed KV
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
Tensor *output_O = reinterpret_cast<Tensor*>(O);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[1];
size_t d = input_Q->data.shape[2];
const DType QKV_type = input_Q->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
// NVTE fused attention BWD FP8 with packed KV
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dKV,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
const Tensor *input_dBias = reinterpret_cast<const Tensor*>(dBias);
const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
Tensor *output_dQ = reinterpret_cast<Tensor*>(dQ);
Tensor *output_dKV = reinterpret_cast<Tensor*>(dKV);
Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
// Q shape is [total_seqs, h, d]
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h = input_Q->data.shape[1];
size_t d = input_Q->data.shape[2];
const DType QKV_type = input_Q->data.dtype;
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
}
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO,
const Tensor *input_M,
const Tensor *input_ZInv,
const Tensor *input_S,
Tensor *input_output_dP,
const Tensor *output_dQKV,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
cudaStream_t stream,
cudnnHandle_t handle);
#endif // end of CUDNN>=8900
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
namespace transformer_engine {
namespace fused_attn {
using namespace transformer_engine;
// get matrix strides based on matrix type
void generateMatrixStrides(
int64_t b, int64_t h,
int64_t s_q, int64_t s_kv,
int64_t d, int64_t* strideA,
NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) {
constexpr int batch_dim_idx = 0;
constexpr int head_dim_idx = 1;
constexpr int seqlen_dim_idx = 2;
constexpr int hidden_dim_idx = 3;
constexpr int seqlen_transpose_dim_idx = 3;
constexpr int hidden_transpose_dim_idx = 2;
constexpr int seqlen_q_dim_idx = 2;
constexpr int seqlen_kv_dim_idx = 3;
switch (matrix) {
case NVTE_QKV_Matrix::NVTE_Q_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * 3 * h * d;
} else {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[hidden_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[hidden_dim_idx] = 1;
strideA[seqlen_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose:
if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 3 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 3 * h * d;
} else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = 2 * h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * 2 * h * d;
} else {
strideA[hidden_transpose_dim_idx] = 1;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_kv * h * d;
}
break;
case NVTE_QKV_Matrix::NVTE_S_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = s_kv;
strideA[head_dim_idx] = s_q * s_kv;
strideA[batch_dim_idx] = h * s_q * s_kv;
break;
case NVTE_QKV_Matrix::NVTE_O_Matrix:
strideA[seqlen_kv_dim_idx] = 1;
strideA[seqlen_q_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[batch_dim_idx] = s_q * h * d;
break;
}
}
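// As a concrete illustration of the cases above: requesting strides for the Q
// matrix in the NVTE_QKV_INTERLEAVED layout yields a [b, h, s_q, d] view into
// the packed [total_seqs, 3, h, d] QKV buffer:
//
//   int64_t stride[4];
//   generateMatrixStrides(b, h, s_q, s_kv, d, stride,
//                         NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
//                         NVTE_QKV_Matrix::NVTE_Q_Matrix);
//   // stride == {s_q * 3 * h * d,  d,  3 * h * d,  1}
//   //             (batch)        (head)  (seqlen) (hidden)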
// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
int32_t *qkv_ragged_offset, int32_t *o_ragged_offset) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b) {
actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid];
}
if (tid < b + 1) {
qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d;
o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d;
}
}
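// Worked example (illustrative numbers): with b = 2, h = 16, d = 64 and
// cu_seqlens_q = {0, 100, 228}, the kernel produces
//   actual_seqlens_q  = {100, 128}
//   qkv_ragged_offset = {0, 100 * 3 * 16 * 64, 228 * 3 * 16 * 64}
//   o_ragged_offset   = {0, 100 * 16 * 64,     228 * 16 * 64}
// A host-side launch covering all b + 1 threads might look like (a sketch):
//   constexpr int kBlock = 128;
//   cu_seqlens_to_offsets<<<(b + 1 + kBlock - 1) / kBlock, kBlock, 0, stream>>>(
//       b, h, d, cu_seqlens_q, actual_seqlens_q, qkv_ragged_offset, o_ragged_offset);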
} // namespace fused_attn
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#include "transformer_engine/transformer_engine.h"
#include <cudnn_frontend.h>
namespace transformer_engine {
namespace fused_attn {
using namespace transformer_engine;
enum NVTE_QKV_Matrix {
NVTE_Q_Matrix = 0, // queries
NVTE_K_Matrix = 1, // keys
NVTE_K_Matrix_Transpose = 2, // keys transposed
NVTE_V_Matrix = 3, // values
NVTE_V_Matrix_Transpose = 4, // value matrix transposed
NVTE_S_Matrix = 5, // output of GEMM1
NVTE_O_Matrix = 6, // final output
};
void generateMatrixStrides(
int64_t b, int64_t h,
int64_t s_q, int64_t s_kv,
int64_t d, int64_t* strideA,
NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);
struct FADescriptor {
std::int64_t b;
std::int64_t h;
std::int64_t s_q;
std::int64_t s_kv;
std::int64_t d;
float attnScale;
bool isTraining;
float dropoutProbability;
NVTE_QKV_Layout layout;
cudnnDataType_t tensor_type;
bool operator<(const FADescriptor &rhs) const {
return std::tie(b, h, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability,
layout, tensor_type) < std::tie(
rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.tensor_type);
}
};
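// The strict weak ordering above lets FADescriptor serve as the key of an
// ordered plan cache. A minimal sketch (the cache type and name are
// assumptions, not the actual implementation):
//
//   #include <map>
//   static thread_local std::map<FADescriptor, cudnn_frontend::ExecutionPlan> fa_plan_cache;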
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] {
if (handle_ != nullptr) {
cudnnDestroy(handle_);
}});
}
private:
cudnnHandle_t handle_ = nullptr;
};
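// Usage in the fused attention entry points (see fused_attn.cpp):
//   auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
// The handle is created lazily, once per thread, and released when the
// thread-local manager is destroyed.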
} // namespace transformer_engine
#endif
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
enum NVTE_QKV_Layout {
/*!< separate Q, K, V tensors:
Q: [total_seqs_q, num_heads, head_dim]
| Q Q Q ... Q
| \___________ _____________/
total_seqs_q <| \/
| num_heads * head_dim
K: [total_seqs_kv, num_heads, head_dim]
| K K K ... K
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
V: [total_seqs_kv, num_heads, head_dim]
| V V V ... V
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
*/
NVTE_NOT_INTERLEAVED = 0,
/*!< packed QKV tensor:
QKV: [total_seqs, 3, num_heads, head_dim]
| Q Q Q ... Q K K K ... K V V V ... V
| \___________ _____________/
total_seqs <| \/
| num_heads * head_dim
*/
NVTE_QKV_INTERLEAVED = 1,
/*!< Q and packed KV tensor:
Q: [total_seqs_q, num_heads, head_dim]
| Q Q Q ... Q
| \___________ _____________/
total_seqs_q <| \/
| num_heads * head_dim
KV: [total_seqs_kv, 2, num_heads, head_dim]
| K K K ... K V V V ... V
| \___________ _____________/
total_seqs_kv <| \/
| num_heads * head_dim
*/
NVTE_KV_INTERLEAVED = 2
};
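/* For reference: with the NVTE_QKV_INTERLEAVED layout, element
 * (seq, matrix, head, dim) of the packed [total_seqs, 3, num_heads, head_dim]
 * QKV tensor sits at the flat offset
 *   ((seq * 3 + matrix) * num_heads + head) * head_dim + dim,
 * where matrix = 0 selects Q, 1 selects K and 2 selects V
 * (a row-major indexing sketch, not part of the API).
 */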
enum NVTE_Bias_Type {
NVTE_NO_BIAS = 0, /*!< no bias */
NVTE_PRE_SCALE_BIAS = 1, /*!< bias before scale */
NVTE_POST_SCALE_BIAS = 2 /*!< bias after scale */
};
enum NVTE_Mask_Type {
NVTE_PADDING_MASK = 0, /*!< padding attention mask */
NVTE_CAUSAL_MASK = 1, /*!< causal attention mask */
NVTE_NO_MASK = 2 /*!< no masking */
};
/*! \brief Compute dot product attention with packed QKV input.
*
* Computes:
* - P = Q * K.T + Bias
* - S = ScaleMaskSoftmax(P)
* - D = Dropout(S)
* - O = D * V.T
*
* Support Matrix:
* | precision | qkv layout | bias | mask | sequence length | head_dim |
* | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 |
*
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for the computation;
* it may be >= max(cu_seqlens).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
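/* A minimal calling sketch: the NVTETensor handles (QKV, Bias, S, O, cu_seqlens,
 * rng_state, workspace) and the auxiliary tensor pack are assumed to have been
 * set up beforehand through the Transformer Engine tensor API, and the scalar
 * values below are illustrative only:
 *
 *   NVTETensorPack aux_output;   // filled with M, ZInv, rng_state when training
 *   nvte_fused_attn_fwd_qkvpacked(
 *       QKV, Bias, S, O, &aux_output,
 *       cu_seqlens, rng_state,
 *       512,                      // max_seqlen (the FP8 path requires <= 512)
 *       true,                     // is_training
 *       0.125f,                   // attn_scale, e.g. 1/sqrt(head_dim) for head_dim = 64
 *       0.0f,                     // dropout
 *       NVTE_QKV_INTERLEAVED, NVTE_NO_BIAS, NVTE_PADDING_MASK,
 *       workspace, stream);
 */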
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
* Support Matrix:
* | precision | qkv layout | bias | mask | sequence length | head_dim |
* | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 |
*
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] dBias The gradient of the Bias tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for the computation;
* it may be >= max(cu_seqlens).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQKV,
const NVTETensor cu_seqlens,
size_t max_seqlen,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
* Computes:
* - P = Q * K.T + Bias
* - S = ScaleMaskSoftmax(P)
* - D = Dropout(S)
* - O = D * V.T
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for the computation for Q;
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for the computation for KV;
* it may be >= max(cu_seqlens_kv).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv,
bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] dBias The gradient of the Bias tensor.
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for the computation for Q;
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for the computation for KV;
* it may be >= max(cu_seqlens_kv).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
const NVTETensor KV,
const NVTETensor dBias,
const NVTETensor O,
const NVTETensor dO,
const NVTETensor S,
NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors,
NVTETensor dQ,
NVTETensor dKV,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
NVTETensor workspace,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif
......@@ -9,6 +9,7 @@
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include <string>
#include <stdexcept>
......@@ -39,10 +40,18 @@ inline void check_cublas_(cublasStatus_t status) {
}
}
inline void check_cudnn_(cudnnStatus_t status) {
if ( status != CUDNN_STATUS_SUCCESS ) {
NVTE_ERROR("CUDNN Error: " + std::string(cudnnGetErrorString(status)));
}
}
} // namespace
#define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); }
#define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); }
#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); }
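// Example: wrap cuDNN calls so failures surface as Transformer Engine errors:
//   cudnnHandle_t handle;
//   NVTE_CHECK_CUDNN(cudnnCreate(&handle));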
#endif // TRANSFORMER_ENGINE_LOGGING_H_