Commit 7b915a10 authored by danyao12's avatar danyao12
Browse files

bwd qloop 2 kernels update mask

parent 2018bd28
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 1 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
...@@ -268,7 +268,7 @@ int run(int argc, char* argv[]) ...@@ -268,7 +268,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 253; ck::index_t M = 512;
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
......
...@@ -86,7 +86,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -86,7 +86,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
...@@ -228,8 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -228,8 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -68,7 +68,7 @@ using GemmDataType = F16; ...@@ -68,7 +68,7 @@ using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = INT32; // U16 using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
using DDataType = F32; using DDataType = F32;
...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
...@@ -227,8 +227,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -227,8 +227,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -359,12 +359,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -359,12 +359,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = 2;
...@@ -666,9 +660,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -666,9 +660,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -791,7 +789,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -791,7 +789,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
// GridwiseYDotYGrad // GridwiseYDotYGrad
...@@ -892,7 +890,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -892,7 +890,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
...@@ -365,12 +365,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -365,12 +365,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = 2;
...@@ -672,9 +666,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -672,9 +666,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -805,7 +803,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -805,7 +803,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
// GridwiseYDotYGrad // GridwiseYDotYGrad
...@@ -905,7 +903,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -905,7 +903,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
...@@ -352,12 +352,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -352,12 +352,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = 2;
...@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -728,7 +726,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -728,7 +726,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -940,7 +938,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -940,7 +938,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -359,12 +359,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -359,12 +359,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8; static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8; static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2; static constexpr index_t Y_M1 = 2;
...@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -736,7 +734,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -736,7 +734,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -948,7 +946,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -948,7 +946,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment