"git@developer.sourcefind.cn:wangkx1/easy_tools.git" did not exist on "932303a0df5aac339e92ae771a1383f69958ed59"
Unverified Commit 0707552e authored by cyanguwa, committed by GitHub
Browse files

Fused attention fixes for cuDNN 8.9.3 (#311)



* Fix bprop for cuDNN 8.9.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update cuDNN version requirement to 8.9.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug paddle CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug paddle CI; force LD_LIBRARY
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug paddle CI; force LD_LIBRARY to /opt
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove debug info for paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change cudnn requirement to 8.9.1 for v1 and 8.9.0 for v2; add batch size 32 for unit test; add LD library path for paddle tests temporarily
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove printf line in fused_attn.cpp
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add batch size 32 for unit test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to 0.9.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove temporary LD library path used for testing pre-released cudnn 8.9.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 58d2ebab
Subproject commit a4f05c1edcef453f5fd52f96218c29c7d420e511
Subproject commit 12f35fa2be5994c1106367cac2fba21457b064f4
......@@ -184,7 +184,7 @@ Pre-requisites
* CUDA 11.8 or later
* NVIDIA Driver supporting CUDA 11.8 or later
* cuDNN 8.1 or later
* For FP8 fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
From source
^^^^^^^^^^^
......
......@@ -30,17 +30,20 @@ class ModelConfig:
model_configs = {
"test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
"test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
"test2": ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal"),
"test3": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test5": ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal"),
"test6": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test7": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
"test8": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)
batch_sizes = [1, 2]
batch_sizes = [1, 2, 32]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
......
......@@ -373,7 +373,7 @@ createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
tensorType, D_CONST_ID, scale_dim,
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
......@@ -454,7 +454,7 @@ createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
tensorType, D_CONST_ID, scale_dim,
CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
......@@ -738,6 +738,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
b, h, s_q, s_kv, d, o_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t dqAccum_dim[4] = {b, h, s_q, d};
int64_t dqAccum_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dqAccum_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
......@@ -770,19 +775,19 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto afterReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, reduction_dim,
reduction_stride, true, false); // is virtual
auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
// Create a reduction max node
auto reductionMax_op = cudnn_frontend::OperationBuilder(
// Create a reduction add node
auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dotProductTensor)
.setyDesc(afterReductionTensor)
.setreductionDesc(reductionMaxDesc)
.setreductionDesc(reductionAddDesc)
.build();
ops.push_back(std::move(reductionMax_op));
ops.push_back(std::move(reductionAdd_op));
/*******************************************************************************
......@@ -895,16 +900,25 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
ops.push_back(std::move(reshape_op));
// Outputs of bprop
int64_t dqkv_dim[4] = {b, h, s_kv, d};
int64_t dqkv_stride[4];
generateMatrixStrides(
b, h, s_q, s_kv, d, dqkv_stride,
int64_t dq_dim[4] = {b, h, s_q, d};
int64_t dq_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dq_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
int64_t dk_dim[4] = {b, h, s_kv, d};
int64_t dk_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dk_stride,
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
int64_t dv_dim[4] = {b, h, s_kv, d};
int64_t dv_stride[4];
generateMatrixStrides(b, h, s_q, s_kv, d, dv_stride,
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
// Outputs of backprop
auto dQTensor = tensor_create(tensorType, dQ_ID, dqkv_dim, dqkv_stride, false, false);
auto dKTensor = tensor_create(tensorType, dK_ID, dqkv_dim, dqkv_stride, false, false);
auto dVTensor = tensor_create(tensorType, dV_ID, dqkv_dim, dqkv_stride, false, false);
auto dQTensor = tensor_create(tensorType, dQ_ID, dq_dim, dq_stride, false, false);
auto dKTensor = tensor_create(tensorType, dK_ID, dk_dim, dk_stride, false, false);
auto dVTensor = tensor_create(tensorType, dV_ID, dv_dim, dv_stride, false, false);
// not virtual
/*******************************************************************************
......@@ -1028,8 +1042,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
* dP @ K -> dqAccumTensor */
auto dqAccumTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dqkv_dim)
.setStride(4, dqkv_stride)
.setDim(4, dqAccum_dim)
.setStride(4, dqAccum_stride)
.setId(dQ_ACCUM_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
......@@ -1044,7 +1058,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.build();
auto matmul_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
.setaMatDesc(dPTensor)
.setaMatDesc(dPScaledTensor)
.setbMatDesc(kTensor)
.setcMatDesc(dqAccumTensor)
.setmatmulDesc(matmul_3_Desc)
......@@ -1060,7 +1074,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
p_transpose_stride, true, false); // is virtual
auto reshape_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
.setxDesc(dPTensor)
.setxDesc(dPScaledTensor)
.setyDesc(dPTransposeTensor)
.build();
ops.push_back(std::move(reshape_op3));
......@@ -1185,7 +1199,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = num_head * head_dim;
const auto stride = 2 * num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
......@@ -1256,7 +1270,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
auto stride = num_head * head_dim;
auto stride = 2 * num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
......
......@@ -680,17 +680,23 @@ void fused_attn_max_512_fwd_impl(
// inference mode doesn't need the S auxiliary
auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training;
std::shared_ptr<cudnn_frontend::Tensor> maskInput;
auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);
NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS,
"NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented.");
if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) {
createBias(b, h, s_q, s_kv, d, layout, tensorType, ops, bmm1_output);
auto bias_output = createBias(b, h, s_q, s_kv, d, layout,
tensorType, ops, bmm1_output);
maskInput = std::make_shared<cudnn_frontend::Tensor>(std::move(bias_output));
}
if (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) {
maskInput = std::make_shared<cudnn_frontend::Tensor>(std::move(bmm1_output));
}
auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops,
bmm1_output, false);
*maskInput.get(), false);
NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0.");
......@@ -1248,7 +1254,7 @@ void fused_attn_max_512_fwd_qkvpacked(
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
const auto stride = num_head * head_dim;
const auto stride = 2 * num_head * head_dim;
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
......@@ -1322,7 +1328,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devPtrQ = input_Q->data.dptr;
// KV shape is [b, s, 2, h, d]
const auto stride = num_head * head_dim;
const auto stride = 2 * num_head * head_dim;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
......@@ -1393,7 +1399,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
auto stride = num_head * head_dim;
auto stride = 2 * num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
......@@ -1453,7 +1459,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
// Q shape is [b, s, h, d]
// KV shape is [b, s, 2, h, d]
auto stride = num_head * head_dim;
auto stride = 2 * num_head * head_dim;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment