"vscode:/vscode.git/clone" did not exist on "e6fbcc8271ebd8cdb4fbb45c1fdb7c298a4f1282"
Unverified commit a150d286, authored by cyanguwa, committed by GitHub

[C/PyTorch] Add workspace optimization for fused attention arbitrary_seqlen backend (#396)



* add workspace optimization for arbitrary_seqlen fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix whitespace for lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add use_workspace_opt to cudnn plan cache and fix workspace estimate
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* modify workspace opt logic; move zero fill to FP8 API only; other minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix try/catch
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix std string error when input is nullptr
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Add = for required vs allowed workspace comparison
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent 4e37499b
@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"])
+        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.1"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
......
@@ -107,7 +107,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   return backend;
 }
-// NVTE fused attention FWD FP8 with packed QKV
+// NVTE fused attention FWD with packed QKV
 void nvte_fused_attn_fwd_qkvpacked(
             const NVTETensor QKV,
             const NVTETensor Bias,
@@ -192,7 +192,7 @@ void nvte_fused_attn_fwd_qkvpacked(
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
 }
-// NVTE fused attention BWD FP8 with packed QKV
+// NVTE fused attention BWD with packed QKV
 void nvte_fused_attn_bwd_qkvpacked(
             const NVTETensor QKV,
             const NVTETensor O,
@@ -291,7 +291,7 @@ void nvte_fused_attn_bwd_qkvpacked(
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
 }
-// NVTE fused attention FWD FP8 with packed KV
+// NVTE fused attention FWD with packed KV
 void nvte_fused_attn_fwd_kvpacked(
             const NVTETensor Q,
             const NVTETensor KV,
@@ -361,7 +361,7 @@ void nvte_fused_attn_fwd_kvpacked(
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
 }
-// NVTE fused attention BWD FP8 with packed KV
+// NVTE fused attention BWD with packed KV
 void nvte_fused_attn_bwd_kvpacked(
             const NVTETensor Q,
             const NVTETensor KV,
......
@@ -14,6 +14,7 @@
 #include "../common.h"
 #include "utils.h"
+#include "../util/cuda_runtime.h"
 #if (CUDNN_VERSION >= 8900)
 #define Q_ID 1
@@ -555,7 +556,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                 d, scaling_factor,
                 is_training, dropout_probability,
                 layout, NVTE_Bias_Type::NVTE_NO_BIAS,
-                NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType};
+                NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType,
+                false};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static thread_local CacheType fmha_fprop_cache;
@@ -677,7 +679,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                                 void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO,
                                 void* devPtrDropoutSeed, void* devPtrDropoutOffset,
                                 cudnnDataType_t tensorType, void *workspace, size_t *workspace_size,
-                                cudaStream_t stream, cudnnHandle_t handle) {
+                                cudaStream_t stream, cudnnHandle_t handle, bool use_workspace_opt) {
     try {
         NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
@@ -686,7 +688,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                 d, scaling_factor,
                 true, dropout_probability,
                 layout, NVTE_Bias_Type::NVTE_NO_BIAS,
-                NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType};
+                NVTE_Mask_Type::NVTE_CAUSAL_MASK, tensorType,
+                use_workspace_opt};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static thread_local CacheType fmha_bprop_cache;
@@ -1039,7 +1042,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         ops.push_back(std::move(reshape_op2));
         /*******************************************************************************
-         *                       dP @ K -> dqAccumTensor */
+         *                       dP @ K -> dqAccumTensor / dqTensor */
         auto dqAccumTensor = cudnn_frontend::TensorBuilder()
                         .setDim(4, dqAccum_dim)
@@ -1056,6 +1059,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder()
                                 .setComputeType(CUDNN_DATA_FLOAT)
                                 .build();
+        if (!use_workspace_opt) {
         auto matmul_op3 = cudnn_frontend::OperationBuilder(
                                 CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
                                 .setaMatDesc(dPScaledTensor)
@@ -1065,6 +1069,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                                 .build();
         ops.push_back(std::move(matmul_op3));
+        } else {
+            auto matmul_op3 = cudnn_frontend::OperationBuilder(
+                                    CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
+                                    .setaMatDesc(dPScaledTensor)
+                                    .setbMatDesc(kTensor)
+                                    .setcMatDesc(dQTensor)
+                                    .setmatmulDesc(matmul_3_Desc)
+                                    .build();
+            ops.push_back(std::move(matmul_op3));
+        }
         /*******************************************************************************
          *                       dP.T @ Q -> dK */
@@ -1095,9 +1110,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         /*******************************************************************************
          *                       dqAccumTensor @ identity -> dqTensor */
+        if (!use_workspace_opt) {
         auto identityDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_IDENTITY);
         auto identity_op = unary_pw_op_create(dqAccumTensor, dQTensor, identityDesc);
         ops.push_back(std::move(identity_op));
+        }
         for (unsigned int i = 0; i < ops.size(); i++) {
             all_ops.push_back(&ops[i]);
@@ -1136,22 +1153,32 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
         size_t softmaxSum_workspace_size = b * h * s_q * sizeof(float);
         size_t dqAccum_workspace_size = b * s_q * h * d * sizeof(float);
         if (workspace == nullptr) {
-            *workspace_size = plan_workspace_size + softmaxSum_workspace_size
-                              + dqAccum_workspace_size;
+            if (use_workspace_opt) {
+                *workspace_size = plan_workspace_size + softmaxSum_workspace_size;
+            } else {
+                *workspace_size = plan_workspace_size + softmaxSum_workspace_size
+                                  + dqAccum_workspace_size;
+            }
             return;
         }
         void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
-        void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
-                                    + softmaxSum_workspace_size;
-        NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));
+        void *devPtrdQAccumulator = nullptr;
+        if (!use_workspace_opt) {
+            devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
+                                  + softmaxSum_workspace_size;
+            NVTE_CHECK_CUDA(cudaMemsetAsync(
+                            devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));
+        }
         std::set<std::pair<uint64_t, void *>> data_ptrs;
         // add all the data pointers to be used in the variant pack
         float negInfinity = -1.0E+10f;
         float scale_dropout = 1.0f/(1.0f - dropout_probability);
         data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ID, devPtrdQ));
-        data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ACCUM_ID, devPtrdQAccumulator));
+        if (!use_workspace_opt) {
+            data_ptrs.insert(std::pair<uint64_t, void*>(dQ_ACCUM_ID, devPtrdQAccumulator));
+        }
         data_ptrs.insert(std::pair<uint64_t, void*>(dK_ID, devPtrdK));
         data_ptrs.insert(std::pair<uint64_t, void*>(dV_ID, devPtrdV));
@@ -1298,13 +1325,44 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
     const auto qkv_type = input_QKV->data.dtype;
     size_t workspace_size = 0;
+    bool use_workspace_opt = false;
+#if (CUDNN_VERSION >= 8905)
+    const int device_id = cuda::current_device();
+    const int sm_arch_ = cuda::sm_arch(device_id);
+    if (sm_arch_ >= 90) {
+        // quick estimate of dp workspace size
+        size_t max_seqlen_div_up_q = ((max_seqlen + 64 - 1) / 64) * 64;
+        size_t max_seqlen_div_up_kv = ((max_seqlen + 64 - 1) / 64) * 64;
+        size_t required_dp_workspace =
+            (batch * num_head * max_seqlen_div_up_q * max_seqlen_div_up_kv * 2 + 1048576 - 1) / 1048576;
+        // default upper limit for dp workspace 256MB
+        size_t max_allowed_dp_workspace = 256;
+        const char* env_workspace_limit_char = std::getenv("NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT");
+        if (env_workspace_limit_char != nullptr) {
+            try {
+                std::string env_dp_workspace_limit(env_workspace_limit_char);
+                int dp_workspace_limit = std::stoi(env_dp_workspace_limit);
+                if (dp_workspace_limit > max_allowed_dp_workspace) {
+                    max_allowed_dp_workspace = dp_workspace_limit;
+                }
+            } catch (...) {
+                NVTE_ERROR(
+                    "Invalid argument for NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT (integer; in MBytes)! \n");
+            }
+        }
+        if (required_dp_workspace <= max_allowed_dp_workspace) {
+            use_workspace_opt = true;
+        }
+    }
+#endif
     fused_attn_arbitrary_seqlen_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim,
                                 attn_scale, p_dropout, qkv_layout,
                                 devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
                                 devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
                                 devPtrDropoutSeed, devPtrDropoutOffset,
-                                get_cudnn_dtype(qkv_type),
-                                workspace->data.dptr, &workspace_size, stream, handle);
+                                get_cudnn_dtype(qkv_type), workspace->data.dptr,
+                                &workspace_size, stream, handle, use_workspace_opt);
     if (workspace_size > 0) {
         if (workspace->data.dptr == nullptr) {
......
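Note on the hunks above: with the optimization enabled, the backward graph writes dP @ K directly into dQ, so the separate FP32 dQ accumulation buffer and its cudaMemsetAsync are skipped and the workspace request shrinks by dqAccum_workspace_size. The snippet below is a rough, illustrative Python re-statement of the enabling heuristic from the -1298 hunk; it is not code from this commit, and the function name and sm_arch argument are assumptions made only for the sketch.

import os

def dp_workspace_opt_enabled(batch, num_head, max_seqlen, sm_arch=90):
    # Mirrors the estimate above: round the sequence length up to a multiple
    # of 64, count 2 bytes per dP element, and convert to MB with a ceiling.
    if sm_arch < 90:
        return False
    s = ((max_seqlen + 64 - 1) // 64) * 64
    required_mb = (batch * num_head * s * s * 2 + 2**20 - 1) // 2**20
    allowed_mb = 256  # default upper limit, 256 MB
    env = os.getenv("NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT")
    if env is not None:
        allowed_mb = max(allowed_mb, int(env))  # the env var can only raise the limit
    return required_mb <= allowed_mb

# Worked example: batch=32, num_head=16, max_seqlen=512 gives
# 32*16*512*512*2 bytes = 256 MB, exactly at the default limit, so the
# optimization is enabled; max_seqlen=1024 needs 1024 MB and stays off unless
# NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT is raised to at least 1024.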
@@ -647,7 +647,8 @@ void fused_attn_max_512_fwd_impl(
                 d, scaling_factor,
                 is_training, dropout_probability,
                 layout, bias_type,
-                mask_type, tensorType};
+                mask_type, tensorType,
+                false};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static thread_local CacheType fmha_fprop_cache;
@@ -846,7 +847,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
     FADescriptor descriptor{
                 b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability,
-                layout, bias_type, mask_type, tensorType};
+                layout, bias_type, mask_type, tensorType, false};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static thread_local CacheType fmha_bprop_cache;
......
@@ -1013,7 +1013,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
     FADescriptor descriptor{
                 b, h, s_q, s_kv, d,
                 attnScale, isTraining, dropoutProbability, layout,
-                NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};
+                NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType, false};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static thread_local CacheType fa_fprop_cache;
@@ -1329,7 +1329,7 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, in
     FADescriptor descriptor{
                 b, h, s_q, s_kv, d,
                 attnScale, false, dropoutProbability, layout,
-                NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};
+                NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType, false};
     using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
     static thread_local CacheType fa_bprop_cache;
......
@@ -80,16 +80,18 @@ struct FADescriptor {
     NVTE_Bias_Type bias_type;
     NVTE_Mask_Type mask_type;
     cudnnDataType_t tensor_type;
+    bool use_workspace_opt;
     bool operator<(const FADescriptor &rhs) const {
         return std::tie(b, h, s_q, s_kv, d,
                         attnScale, isTraining, dropoutProbability,
-                        layout, mask_type, bias_type, tensor_type)
+                        layout, mask_type, bias_type, tensor_type, use_workspace_opt)
                < std::tie(
                         rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
                         rhs.attnScale, rhs.isTraining,
                         rhs.dropoutProbability, rhs.layout,
-                        rhs.mask_type, rhs.bias_type, rhs.tensor_type);
+                        rhs.mask_type, rhs.bias_type,
+                        rhs.tensor_type, rhs.use_workspace_opt);
     }
 };
......
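Note on the utils.h hunk: FADescriptor is the key of the thread-local execution-plan caches used in the files above, so the new flag has to take part in operator<; otherwise a plan built without the optimization could be returned for a call that expects the optimized graph, and vice versa. The following is a minimal Python analogue of that cache-key idea, purely illustrative and not part of the repository.

# Plan cache keyed by a tuple; including use_workspace_opt keeps "opt" and
# "no-opt" plans separate, just as adding the field to std::tie does in C++.
_plan_cache = {}

def get_plan(b, h, s_q, s_kv, d, use_workspace_opt, build_plan):
    key = (b, h, s_q, s_kv, d, use_workspace_opt)
    if key not in _plan_cache:
        _plan_cache[key] = build_plan(use_workspace_opt)
    return _plan_cache[key]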
@@ -455,7 +455,7 @@ def _check_kv_layout(k, v):
 class FlashAttention(torch.nn.Module):
     """Dot product attention, using HazyResearch flash-attn package:
-    https://github.com/HazyResearch/flash-attention
+    https://github.com/Dao-AILab/flash-attention
     """
     def __init__(
@@ -709,7 +709,7 @@ class FusedAttention(torch.nn.Module):
         self.attention_dropout = attention_dropout
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_type = attention_type
-        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "1") == "1"
+        self.use_FAv2_bwd = (os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
                              and _flash_attn_2_available
                              and get_device_compute_capability() == 9.0)
@@ -1055,14 +1055,27 @@ class DotProductAttention(torch.nn.Module):
     .. note::
-        `DotProductAttention` supports three backends: 1) `FlashAttention` which calls
-        HazyResearch's FlashAttention PyTorch API, 2) `FusedAttention` which has multiple
-        fused attention implementations as its backends (see `FusedAttention` for
-        more details), and 3) `UnfusedDotProductAttention` which is the native PyTorch
-        implementation with fused scaled masked softmax. Users can use environment variables
-        `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, and `NVTE_FUSED_ATTN_BACKEND` to control
-        which DotProductAttention backend, and FusedAttention backend if applicable, to use.
-        The default DotProductAttention backend is 1.
+        DotProductAttention supports three backends: 1) FlashAttention which calls
+        HazyResearch/Dao-AILab's `flash-attn <https://arxiv.org/pdf/2305.13245.pdf>`_
+        PyTorch API, 2) FusedAttention which has multiple fused attention implementations
+        based on `cuDNN Graph API
+        <https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#op-fusion>`_
+        (see :attr:`FusedAttention` for more details on FusedAttention backends), and 3)
+        UnfusedDotProductAttention which is the native PyTorch implementation
+        with fused scaled masked softmax.
+
+    .. note::
+        Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
+        and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
+        and FusedAttention backend if applicable, to use. TransformerEngine prioritizes
+        FlashAttention over FusedAttention and over UnfusedDotProductAttention.
+        If FusedAttention is being used, users can also choose to switch to flash-attn's
+        implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
+        (default: 0), because of the performance differences between various versions of
+        flash-attn and FusedAttention. Further, :attr:`NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT`
+        can be used to enable the workspace related optimizations in FusedAttention
+        (default: 256MB; raise the limit to enable these performance optimizations).
     Parameters
     ----------
@@ -1085,7 +1098,7 @@ class DotProductAttention(torch.nn.Module):
               Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`}
     core_attention_bias: Optional[torch.Tensor], default = `None`
               Bias tensor for Q * K.T
-    fast_zero_fill: bool, defautl = `True`
+    fast_zero_fill: bool, default = `True`
               Whether to use the fast path to set output tensors to 0 or not.
     """
......
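Usage note for the docstring changes above: the backend-selection knobs are plain environment variables, so they can be set before the model runs. A hypothetical example follows; the environment variable names come from this diff, while the DotProductAttention constructor arguments and tensor shapes are illustrative assumptions, not taken from it.

import os

# Prefer FusedAttention over FlashAttention and raise the dP workspace limit
# to 512 MB so the arbitrary_seqlen backward optimization can be enabled.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_DP_WORKSPACE_LIMIT"] = "512"
# Optionally use flash-attn v2 for the backward pass instead (default: off):
# os.environ["NVTE_FUSED_ATTN_USE_FAv2_BWD"] = "1"

import torch
from transformer_engine.pytorch import DotProductAttention

attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v)  # default layout: [seq, batch, heads, head_dim]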
@@ -108,16 +108,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
   auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
   auto O = torch::empty({static_cast<int64_t>(total_seqs),
                   static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
-  if (set_zero && (h * d % block_size == 0)) {
-    mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
-  } else {
-    O.fill_(0);
-  }
   // construct NVTE tensors
   TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
+    if (set_zero && (h * d % block_size == 0)) {
+      mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
+    } else {
+      O.fill_(0);
+    }
     if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
         || (!amax_S.has_value()) || (!amax_O.has_value())) {
       std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
@@ -252,19 +252,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   // create output tensor dQKV
   at::Tensor dQKV = torch::empty_like(QKV);
-  auto max_tokens = dQKV.size(0);
-  auto self_2d = dQKV.view({max_tokens, -1});
-  auto fcd_size = self_2d.size(1);
-  if (set_zero && (fcd_size % block_size == 0)) {
-    mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
-  } else {
-    dQKV.fill_(0);
-  }
   auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
   at::Tensor dBias;
   TensorWrapper te_dBias;
   if (bias_type != NVTE_NO_BIAS) {
-    dBias = torch::zeros({1, static_cast<int64_t>(h),
+    dBias = torch::empty({1, static_cast<int64_t>(h),
                     static_cast<int64_t>(max_seqlen),
                     static_cast<int64_t>(max_seqlen)}, options);
     te_dBias = makeTransformerEngineTensor(dBias);
@@ -274,6 +266,14 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
   TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
+    auto max_tokens = dQKV.size(0);
+    auto self_2d = dQKV.view({max_tokens, -1});
+    auto fcd_size = self_2d.size(1);
+    if (set_zero && (fcd_size % block_size == 0)) {
+      mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)}));
+    } else {
+      dQKV.fill_(0);
+    }
     if ((!descale_QKV.has_value()) || (!descale_S.has_value())
         || (!descale_O.has_value()) || (!descale_dO.has_value())
         || (!scale_S.has_value()) || (!scale_dP.has_value())
@@ -409,16 +409,16 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
   auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
   auto O = torch::empty({static_cast<int64_t>(total_seqs_q),
                   static_cast<int64_t>(h), static_cast<int64_t>(d)}, options);
-  if (set_zero && (h * d % block_size == 0)) {
-    mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
-  } else {
-    O.fill_(0);
-  }
   // construct NVTE tensors
   TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
+    if (set_zero && (h * d % block_size == 0)) {
+      mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
+    } else {
+      O.fill_(0);
+    }
     if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value())
         || (!amax_S.has_value()) || (!amax_O.has_value())) {
       std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O";
@@ -567,24 +567,11 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   // create output tensors dQ and dKV
   at::Tensor dQ = torch::empty_like(Q);
   at::Tensor dKV = torch::empty_like(KV);
-  auto max_tokens_q = dQ.size(0);
-  auto self_2d_q = dQ.view({max_tokens_q, -1});
-  auto fcd_size_q = self_2d_q.size(1);
-  auto max_tokens_kv = dQ.size(0);
-  auto self_2d_kv = dQ.view({max_tokens_kv, -1});
-  auto fcd_size_kv = self_2d_kv.size(1);
-  if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) {
-    mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
-    mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
-  } else {
-    dQ.fill_(0);
-    dKV.fill_(0);
-  }
   auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
   at::Tensor dBias;
   TensorWrapper te_dBias;
   if (bias_type != NVTE_NO_BIAS) {
-    dBias = torch::zeros({1, static_cast<int64_t>(h),
+    dBias = torch::empty({1, static_cast<int64_t>(h),
                     static_cast<int64_t>(max_seqlen_q),
                     static_cast<int64_t>(max_seqlen_kv)}, options);
     te_dBias = makeTransformerEngineTensor(dBias);
@@ -594,6 +581,19 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
   TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
+    auto max_tokens_q = dQ.size(0);
+    auto self_2d_q = dQ.view({max_tokens_q, -1});
+    auto fcd_size_q = self_2d_q.size(1);
+    auto max_tokens_kv = dQ.size(0);
+    auto self_2d_kv = dQ.view({max_tokens_kv, -1});
+    auto fcd_size_kv = self_2d_kv.size(1);
+    if (set_zero && (fcd_size_q % block_size == 0) && (fcd_size_kv % block_size == 0)) {
+      mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
+      mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
+    } else {
+      dQ.fill_(0);
+      dKV.fill_(0);
+    }
    if ((!descale_QKV.has_value()) || (!descale_S.has_value())
        || (!descale_O.has_value()) || (!descale_dO.has_value())
        || (!scale_S.has_value()) || (!scale_dP.has_value())
......
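Note on the extension hunks above: the non-FP8 paths no longer pre-zero O / dQKV / dQ / dKV before calling the fused kernels; the fill (fast mha_fill path or plain fill_(0)) now happens only on the FP8 branch, apparently because only that path still relies on deterministic padding bytes in the output buffers. As a rough illustration of the "fast zero fill" idea behind fast_zero_fill, and explicitly not the repository's mha_fill implementation, a sketch:

import torch

def fill_padding_tail(out, cu_seqlens):
    # Zero only the token rows past the last valid token instead of zeroing
    # the whole buffer, which is what makes the fast path cheaper than fill_(0).
    valid_tokens = int(cu_seqlens[-1])
    out.view(out.size(0), -1)[valid_tokens:].zero_()
    return out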