"vscode:/vscode.git/clone" did not exist on "31348dff03d638eb66abda9bec94b8992de9c7a1"
Unverified commit 989a53a0, authored by cyanguwa, committed by GitHub
Browse files

Add FP8 fused attention (#155)



* Add FP8 fused attention to TE for PyTorch
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* add c api docs for fused attention
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* add exception for unsupported precision/sequence length combinations
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* fix installation requirement for non fused attn use cases
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* fix docs for fused-attn
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* minor fixes based on PR comments
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* fix description for kvpacked fwd
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* fix description of Bias in C api
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* minor fixes for cudnn requirement and description for QKV tensors
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* fix QKV layout description and support matrix for C api
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* add asserts to cpp_extensions for qkv layout/bias type/attn mask type
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

* fix typo precision
Signed-off-by: default avatarCharlene Yang <charleney@nvidia.com>

---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c3407300
......@@ -24,11 +24,12 @@ extern "C" {
/*! \brief TE datatype exposed through the C API.
 *
 *  Numeric values are part of the ABI and must stay in sync with
 *  transformer_engine::DType. kNVTEInt64 was inserted at value 2,
 *  shifting all floating-point types up by one.
 */
enum NVTEDType {
  kNVTEByte = 0,       /*!< Byte */
  kNVTEInt32 = 1,      /*!< 32-bit integer */
  kNVTEInt64 = 2,      /*!< 64-bit integer */
  kNVTEFloat32 = 3,    /*!< 32-bit float */
  kNVTEFloat16 = 4,    /*!< 16-bit float (E5M10) */
  kNVTEBFloat16 = 5,   /*!< 16-bit bfloat (E8M7) */
  kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
  kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
  kNVTENumTypes        /*!< Number of supported types */
};
......@@ -129,6 +130,19 @@ float *nvte_tensor_scale(const NVTETensor tensor);
*/
float *nvte_tensor_scale_inv(const NVTETensor tensor);
/*! \brief Fixed-capacity pack of auxiliary tensor handles.
 *
 *  Carries a variable number of auxiliary tensors across the C API
 *  (e.g. the Aux_CTX_Tensors used by the fused-attention calls below).
 *  The handles are allocated with nvte_tensor_pack_create() and released
 *  with nvte_tensor_pack_destroy(); the pack never owns tensor memory.
 *  NOTE(review): the default member initializer makes this struct
 *  C++-only even though it appears in the extern "C" section — confirm
 *  this header is never compiled as plain C.
 */
struct NVTETensorPack {
static const int MAX_SIZE = 10; /*!< capacity: we expect <10 matrices in auxiliary outputs */
NVTETensor tensors[MAX_SIZE]; /*!< wrappers to tensors, do not hold memory */
size_t size = 0; /*!< actual size of the tensor pack, 0 <= size <= MAX_SIZE */
};
/*! \brief Create NVTETensors in NVTETensorPack.
*/
void nvte_tensor_pack_create(NVTETensorPack* pack);
/*! \brief Destroy NVTETensors in NVTETensorPack.
*/
void nvte_tensor_pack_destroy(NVTETensorPack* pack);
#ifdef __cplusplus
} // extern "C"
......@@ -146,11 +160,12 @@ namespace transformer_engine {
/*! \brief Internal TE datatype.
 *
 *  Values must stay in one-to-one correspondence with the C-level
 *  NVTEDType enum; kInt64 was inserted at value 2, shifting the
 *  floating-point entries up by one.
 */
enum class DType {
  kByte = 0,       /*!< Byte */
  kInt32 = 1,      /*!< 32-bit integer */
  kInt64 = 2,      /*!< 64-bit integer */
  kFloat32 = 3,    /*!< 32-bit float */
  kFloat16 = 4,    /*!< 16-bit float (E5M10) */
  kBFloat16 = 5,   /*!< 16-bit bfloat (E8M7) */
  kFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
  kFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
  kNumTypes        /*!< Number of supported types */
};
......
......@@ -133,3 +133,16 @@ float *nvte_tensor_scale_inv(const NVTETensor tensor) {
"Tensor's inverse of scale must have Float32 type!");
return reinterpret_cast<float*>(t.scale_inv.dptr);
}
/*! \brief Create NVTETensors in NVTETensorPack.
 *
 *  Fills every slot (all MAX_SIZE of them, regardless of pack->size)
 *  with a freshly allocated transformer_engine::Tensor wrapped in the
 *  opaque NVTETensor handle.
 */
void nvte_tensor_pack_create(NVTETensorPack* pack) {
  for (int idx = 0; idx < NVTETensorPack::MAX_SIZE; ++idx) {
    auto* wrapper = new transformer_engine::Tensor;
    pack->tensors[idx] = reinterpret_cast<NVTETensor>(wrapper);
  }
}
/*! \brief Destroy NVTETensors in NVTETensorPack.
 *
 *  Deletes each wrapper allocated by nvte_tensor_pack_create(). Slots are
 *  reset to nullptr and size is cleared so a stale pack cannot be
 *  double-freed or dereferenced after destruction.
 */
void nvte_tensor_pack_destroy(NVTETensorPack* pack) {
  for (int i = 0; i < NVTETensorPack::MAX_SIZE; i++) {
    auto *t = reinterpret_cast<transformer_engine::Tensor*>(pack->tensors[i]);
    delete t;
    pack->tensors[i] = nullptr;  // guard against dangling reuse / double delete
  }
  pack->size = 0;
}
......@@ -14,7 +14,7 @@ extension. Has one to one mapping
with enum in transformer_engine.h
"""
TE_DType = {
torch.int8: tex.DType.kByte,
torch.uint8: tex.DType.kByte,
torch.int32: tex.DType.kInt32,
torch.float32: tex.DType.kFloat32,
torch.half: tex.DType.kFloat16,
......
......@@ -88,6 +88,19 @@ size_t product(const std::vector<size_t> &shape) {
}
/*! \brief Allocate a CUDA tensor with the given shape and TE dtype.
 *
 *  \param shape          Dimensions (size_t), converted to ATen's int64 shape.
 *  \param type           TE dtype, mapped to the ATen dtype via GetATenDType.
 *  \param init_to_zeros  If true the buffer is zero-filled, otherwise left
 *                        uninitialized.
 */
at::Tensor allocateSpace(const std::vector<size_t>& shape,
                         const transformer_engine::DType type,
                         bool init_to_zeros) {
  const std::vector<int64_t> dims(shape.begin(), shape.end());
  const c10::IntArrayRef ref(dims);
  return init_to_zeros ? at::zeros(ref, at::CUDA(GetATenDType(type)))
                       : at::empty(ref, at::CUDA(GetATenDType(type)));
}
at::Tensor allocateSpace(const NVTEShape &shape,
const transformer_engine::DType type,
bool init_to_zeros) {
......
......@@ -15,9 +15,15 @@
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/fused_attn.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cuda.h>
......@@ -101,6 +107,12 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kBFloat16;
case at::kBool:
return transformer_engine::DType::kByte;
case torch::kByte:
return transformer_engine::DType::kByte;
case torch::kInt32:
return transformer_engine::DType::kInt32;
case torch::kInt64:
return transformer_engine::DType::kInt64;
default:
NVTE_ERROR("Invalid type");
}
......@@ -141,6 +153,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
size_t product(const std::vector<size_t> &shape);
at::Tensor allocateSpace(const std::vector<size_t>& shape,
const transformer_engine::DType type,
bool init_to_zeros);
at::Tensor allocateSpace(const NVTEShape &shape,
const transformer_engine::DType type,
......
......@@ -5,7 +5,95 @@
************************************************************************/
#include "common.h"
#include "../common.h"
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout);
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type);
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
const at::Tensor KV,
const at::Tensor O,
const at::Tensor dO,
const transformer_engine::DType qkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV,
const c10::optional<at::Tensor> dBias);
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
......
......@@ -102,7 +102,7 @@ def get_workspace() -> torch.Tensor:
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.int8, device="cuda"
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
......@@ -520,7 +520,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.int8,
dtype=torch.uint8,
),
)
setattr(
......@@ -530,7 +530,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.int8,
dtype=torch.uint8,
),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment