Commit 7b915a10 authored by danyao12's avatar danyao12
Browse files

bwd qloop 2 kernels update mask

parent 2018bd28
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK 1
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
......@@ -268,7 +268,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 253;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
......
......@@ -86,7 +86,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......@@ -228,8 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -68,7 +68,7 @@ using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = INT32; // U16
using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
using DDataType = F32;
......@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......@@ -227,8 +227,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -359,12 +359,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
......@@ -666,9 +660,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTriangleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -791,7 +789,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
// GridwiseYDotYGrad
......@@ -892,7 +890,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
......@@ -365,12 +365,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
......@@ -672,9 +666,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTriangleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -805,7 +803,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
// GridwiseYDotYGrad
......@@ -905,7 +903,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
......@@ -352,12 +352,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
......@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTriangleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -728,7 +726,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -940,7 +938,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
......
......@@ -359,12 +359,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t Q_K1 = 8;
static constexpr index_t K_K1 = 8;
static constexpr index_t V_N1 = 2;
static constexpr index_t Q_M1 = 2;
static constexpr index_t K_N1 = 2;
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
......@@ -604,9 +598,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTriangleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -736,7 +734,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -948,7 +946,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment