Commit 7e003d31 authored by aska-0096

Porting new blockwise gemm to flash attention

parent 84b4ada5
@@ -100,12 +100,12 @@ using DeviceGemmInstance =
         32,         // KPerBlock
         8,          // K1
         // Gemm 1
         64,         // NPerBlock
-        32,         // LPerBlock
+        32,         // LTilePerBlock
         8,          // L1
         16,         // MPerWMMA
         16,         // LPerWMMA
         16,         // NPerWMMA
         // Per repeat = wave_m = wave_num, wave_n = 1
         1,          // MRepeat
         8,          // LRepeat
@@ -124,7 +124,7 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<4, 8, 8>, // B1BlockTransfer LN -> L0 N L1
+        S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
...
@@ -122,20 +122,20 @@ int run(int argc, char* argv[])
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
-    case 5: // Rand: b1 ; unit: a b0 fail
+    case 5: // Rand: b1 b0; unit: a
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
-    case 6: // Rand: b0 ; unit: a b1 pass
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+    case 6: // Rand: a b0 ; unit: b1 pass
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
-    case 7: // Rand: a ; unit: b0 b1 pass
+    case 7: // Rand: a b1 ; unit: b0 pass
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
...
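Note: the cases above toggle which operand gets constant ("unit") versus random data, which is what makes them useful for bisecting numerical failures per operand. A minimal standalone analogue of the two generators, with the CK functors replaced by illustrative stand-ins (member names here are assumptions, not CK's exact definitions):

```cpp
#include <cstdlib>
#include <iostream>
#include <vector>

// Illustrative stand-ins for the GeneratorTensor_1 / GeneratorTensor_2 functors.
template <typename T>
struct GeneratorTensor_1
{
    T value = 1; // constant fill, the "unit" cases above
    T operator()() const { return value; }
};

template <typename T>
struct GeneratorTensor_2
{
    int min_value = 0;
    int max_value = 1;
    // random integer in [min_value, max_value), the "Rand" cases above
    T operator()() const
    {
        return static_cast<T>(min_value + std::rand() % (max_value - min_value));
    }
};

int main()
{
    std::vector<float> a(8), b0(8);
    for(auto& x : a)  x = GeneratorTensor_1<float>{}();       // unit: a
    for(auto& x : b0) x = GeneratorTensor_2<float>{-2, 2}();  // Rand: b0 in [-2, 2)
    for(auto x : b0)  std::cout << x << ' ';
    std::cout << '\n';
}
```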
@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

 #define CK_MNK_LOOP
@@ -340,6 +341,7 @@ struct BlockwiseGemmWMMA
                         b_thread_desc_,
                         make_tuple(I0, n0, I0, I0, I0),
                         b_thread_buf);
+
                     vector_type<FloatA, WmmaK> a_thread_vec;
                     vector_type<FloatB, WmmaK> b_thread_vec;
@@ -413,7 +415,7 @@ struct BlockwiseGemmWMMA
            A_K1,
            0x76543210,
            0xfedcba98,
-            true>;
+            TransposeC ? false : true>;
    };

    template <bool EnableLds>
@@ -448,7 +450,7 @@ struct BlockwiseGemmWMMA
            B_K1,
            0x76543210,
            0xfedcba98,
-            false>;
+            TransposeC ? true : false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
...
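Note: the two edits above are mirror images, since `TransposeC ? false : true` is `!TransposeC` and `TransposeC ? true : false` is `TransposeC`, so the A-side and B-side thread-copy selectors now always receive opposite values of the same flag. What the flag controls is internal to the selectors; the sketch below only shows the inversion pattern at compile time (all names are illustrative):

```cpp
#include <type_traits>

// Illustrative: a per-operand thread-copy type keyed on one boolean whose
// role swaps between the A and B selectors when the WMMA output is transposed.
template <bool Flag>
struct ThreadCopy {};

template <bool TransposeC>
struct Selectors
{
    using AType = ThreadCopy<!TransposeC>; // was hard-coded true
    using BType = ThreadCopy<TransposeC>;  // was hard-coded false
};

// With TransposeC = false the original hard-coded behavior is preserved.
static_assert(std::is_same_v<Selectors<false>::AType, ThreadCopy<true>>);
static_assert(std::is_same_v<Selectors<false>::BType, ThreadCopy<false>>);
// With TransposeC = true both flags flip.
static_assert(std::is_same_v<Selectors<true>::AType, ThreadCopy<false>>);
static_assert(std::is_same_v<Selectors<true>::BType, ThreadCopy<true>>);

int main() {}
```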
@@ -56,11 +56,11 @@ template <index_t NumDimG,
           ck::index_t KPerBlock,
           ck::index_t K1,
           ck::index_t NPerBlock,
-          ck::index_t LPerBlock,
+          ck::index_t LTilePerBlock,
           ck::index_t L1,
-          ck::index_t MPerWMMA,
-          ck::index_t LPerWMMA,
-          ck::index_t NPerWMMA,
+          ck::index_t MPerWmma,
+          ck::index_t LPerWmma,
+          ck::index_t NPerWmma,
           ck::index_t MRepeat,
           ck::index_t LRepeat,
           ck::index_t NRepeat,
@@ -134,15 +134,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto WmmaK = 16;

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
-    static constexpr auto WmmaK = 16;

    static constexpr auto AEnableLds = LWaves == 1 ? false : true;
-    // static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
-    // static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
+    static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
+    static constexpr auto B1EnableLds = MWaves == 1 ? false : true;

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
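Note: a quick sanity check on the wave arithmetic above, with assumed tile sizes (the concrete `MPerBlock`/`LPerBlock` values are not visible in this hunk):

```cpp
// Compile-time check of the wave-count formulas used above.
// Tile sizes below are assumed for illustration only.
constexpr int MPerBlock = 128, MRepeat = 1, MPerWmma = 16;
constexpr int LPerBlock = 128, LRepeat = 8, LPerWmma = 16;
constexpr int NPerBlock = 64,  NRepeat = 1, NPerWmma = 16;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 8
constexpr int LWaves = LPerBlock / (LRepeat * LPerWmma); // 1
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 4

// LDS staging for A is skipped only when a single wave covers L.
constexpr bool AEnableLds = LWaves == 1 ? false : true;
static_assert(MWaves == 8 && LWaves == 1 && NWaves == 4 && !AEnableLds);

int main() {}
```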
@@ -165,14 +168,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        else
        {
            return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
                Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
-                WmmaK, Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{}, Number<K1>{})
+                Number<WmmaK>{},
+                Number<MRepeat>{},
+                Number<MWaves>{},
+                Number<MPerWmma>{},
+                Number<K1>{});
        }
    }

-    static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
-                                               const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
+    static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
+                                     const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
    {
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
@@ -188,7 +194,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    }

    using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
-    using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
+    using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
    using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
    using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
@@ -277,11 +283,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        KPerBlock,
        K1,
        NPerBlock,
-        LPerBlock,
+        LTilePerBlock,
        L1,
-        MPerWMMA,
-        LPerWMMA,
-        NPerWMMA,
+        MPerWmma,
+        LPerWmma,
+        NPerWmma,
        MRepeat,
        LRepeat,
        NRepeat,
@@ -357,10 +363,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
-              a_grid_desc_ak0_m_ak1_{
-                  DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
-              b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
-                  b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
+              a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+              b0_grid_desc_bk0_l_bk1_{
+                  DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
                  b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
@@ -405,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            ignore = acc1_biases_gs_ms_ns_lengths;
            ignore = acc1_biases_gs_ms_ns_strides;

-            if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
+            if(GridwiseOp::CheckValidity(a_grid_desc,
                                         b0_grid_desc_bk0_l_bk1_,
                                         b1_grid_desc_bl0_n_bl1_,
                                         c_grid_desc_m_n_,
@@ -424,7 +429,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        CDataType* p_c_grid_;

        // Tensor Descriptors
-        AGridDesc a_grid_desc_ak0_m_ak1_;
+        AGridDesc a_grid_desc;
        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        CGridDesc_M_N c_grid_desc_m_n_;
@@ -473,8 +478,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

-            const auto K =
-                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+            const auto K = [&]() {
+                if constexpr(AEnableLds)
+                {
+                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
+                }
+                else
+                {
+                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
+                           arg.a_grid_desc.GetLength(I5);
+                }
+            }();

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
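Note: the layout of `a_grid_desc` now depends on `AEnableLds`. With LDS staging it is the familiar (AK0, M, AK1) descriptor, so K = len(I0) * len(I2); in the direct-to-VGPR path it is the 6-d (KWmma, MRepeat, MWaves, KRow, MPerWmma, K1) descriptor, so K = len(I0) * len(I3) * len(I5). A plain-array sketch of the same recovery (lengths below are assumed):

```cpp
#include <cassert>

int main()
{
    // Assumed example: K = 64, K1 = 8, WmmaK = 16 -> KRow = WmmaK / K1 = 2.
    // LDS path: lengths of (AK0, M, AK1), with AK0 = K / AK1 = 8.
    const int lds_lengths[3] = {8, 128, 8};
    const int k_lds = lds_lengths[0] * lds_lengths[2];

    // Direct-load path: lengths of (KWmma, MRepeat, MWaves, KRow, MPerWmma, K1),
    // with KWmma = K / WmmaK = 4.
    const int direct_lengths[6] = {4, 1, 8, 2, 16, 8};
    const int k_direct = direct_lengths[0] * direct_lengths[3] * direct_lengths[5];

    assert(k_lds == 64 && k_direct == 64); // both layouts recover the same K
    return 0;
}
```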
@@ -506,7 +520,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                    arg.p_b0_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
-                    arg.a_grid_desc_ak0_m_ak1_,
+                    arg.a_grid_desc,
                    arg.b0_grid_desc_bk0_l_bk1_,
                    arg.b1_grid_desc_bl0_n_bl1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -551,20 +565,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
+                printf("DeviceOp: Acc0 Type err");
                return false;
            }

            if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
            {
+                printf("DeviceOp: Acc1 Type err");
                return false;
            }
        }
        else
        {
+            printf("DeviceOp: Arch err");
            return false;
        }

-        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+        if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b0_grid_desc_bk0_l_bk1_,
                                      arg.b1_grid_desc_bl0_n_bl1_,
                                      arg.c_grid_desc_m_n_,
@@ -574,14 +591,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
-        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
-        const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
-        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-        const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);

-        if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
+        if(!(c_g == arg.batch_count_))
        {
+            printf("DeviceOp: BatchCount err");
            return false;
        }
@@ -604,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
+            printf("DeviceOp: Data Transfer Vector scalar err");
            return false;
        }
@@ -619,6 +634,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
+            printf("DeviceOp: Data Vectorize transfer err");
            return false;
        }
@@ -765,7 +781,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            << K1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-            << LPerBlock << ", "
+            << LTilePerBlock << ", "
            << L1
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
...
@@ -23,7 +23,7 @@ template <typename GridwiseGemm,
          typename FloatB0,
          typename FloatB1,
          typename FloatC,
-          typename AGridDesc_AK0_M_AK1,
+          typename AGridDesc,
          typename B0GridDesc_BK0_L_BK1,
          typename B1GridDesc_BL0_N_BL1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -45,7 +45,7 @@ __global__ void
                        const FloatB0* __restrict__ p_b0_grid,
                        const FloatB1* __restrict__ p_b1_grid,
                        FloatC* __restrict__ p_c_grid,
-                        const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+                        const AGridDesc a_grid_desc,
                        const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
                        const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
                        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -81,7 +81,7 @@ __global__ void
                                  p_b1_grid + b1_batch_offset,
                                  p_c_grid + c_batch_offset,
                                  p_shared,
-                                  a_grid_desc_ak0_m_ak1,
+                                  a_grid_desc,
                                  b0_grid_desc_bk0_l_bk1,
                                  b1_grid_desc_l0_n_l1,
                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -97,7 +97,7 @@ __global__ void
    ignore = p_b0_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
-    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = a_grid_desc;
    ignore = b0_grid_desc_bk0_l_bk1;
    ignore = b1_grid_desc_l0_n_l1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -128,7 +128,7 @@ template <typename FloatA,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
+          typename AGridDesc,
          typename B0GridDesc_BK0_L_BK1,
          typename B1GridDesc_BL0_N_BL1,
          typename CGridDesc_M_N,
@@ -137,7 +137,7 @@ template <typename FloatA,
          index_t KPerBlock,
          index_t K1Value,
          index_t NPerBlock,
-          index_t LPerBlock,
+          index_t LTilePerBlock,
          index_t L1Value,
          index_t MPerWmma,
          index_t LPerWmma,
@@ -194,14 +194,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    static constexpr auto I7 = Number<7>{};

    static constexpr auto AK1 = Number<K1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock/K1Value>{};
+    static constexpr auto BK0 = Number<KPerBlock / K1Value>{};
    static constexpr auto BK1 = Number<K1Value>{};
-    static constexpr auto L0PerBlock = LPerBlock / L1Value;
+    static constexpr auto L0PerBlock = LTilePerBlock / L1Value;
    static constexpr auto AL0 = Number<L0PerBlock / 2>{};
    static constexpr auto AL1 = Number<L1Value>{};
    static constexpr auto BL0 = Number<L0PerBlock>{};
    static constexpr auto BL1 = Number<L1Value>{};

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
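Note: with the renamed parameter the derived constants read naturally off the example instance at the top of the diff. Taking `KPerBlock = 32`, `K1 = 8`, `LTilePerBlock = 32`, `L1 = 8` from that instance:

```cpp
// Derived per-block constants, using the example instance's values.
constexpr int KPerBlock = 32, K1Value = 8;
constexpr int LTilePerBlock = 32, L1Value = 8;

constexpr int BK0        = KPerBlock / K1Value;          // 4
constexpr int L0PerBlock = LTilePerBlock / L1Value;      // 4
constexpr int AL0        = L0PerBlock / 2;               // 2, mirrors Number<L0PerBlock / 2>{}
constexpr int BL0        = L0PerBlock;                   // 4

static_assert(BK0 == 4 && L0PerBlock == 4 && AL0 == 2 && BL0 == 4);

int main() {}
```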
@@ -209,8 +209,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<decltype(
-        GridwiseGemmPipeline_Selector<PipelineVer, AEnableLds, B0EnableLds,NumGemmKPrefetchStage, LoopSched>())>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
+                                                              AEnableLds,
+                                                              B0EnableLds,
+                                                              NumGemmKPrefetchStage,
+                                                              LoopSched>())>;

    __host__ __device__ static constexpr auto MakeABlockDescriptor()
    {
@@ -238,7 +242,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
                // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
                return make_naive_tensor_descriptor(
-                    make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
+                    make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, AK1),
                    make_tuple(Number<MRepeat>{} * AK1, AK1, AK1, AK1, AK1, I1));
            }
        }();
@@ -349,9 +353,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr auto
    MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
    {
        constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
        constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
-        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        return transform_tensor_descriptor(
            B1BlockDesc_BL0_N_BL1{},
            make_tuple(make_pass_through_transform(Number<B_K0>{}),
@@ -399,16 +403,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    {
        // LDS allocation for A and B: be careful of alignment
        const index_t gemm0_bytes_end =
-            (SharedMemTrait::a_block_space_size_aligned +
-             SharedMemTrait::b0_block_space_size_aligned);
+            (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
+             SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));

        const index_t gemm1_bytes_end =
-            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned);
+            (SharedMemTrait::b1_block_space_offset +
+             SharedMemTrait::b1_block_space_size_aligned * sizeof(FloatB1));

-        const index_t softmax_bytes_end = SharedMemTrait::reduction_space_offset +
-                                          SharedMemTrait::reduction_space_size_aligned;
+        const index_t softmax_bytes_end =
+            SharedMemTrait::reduction_space_offset +
+            SharedMemTrait::reduction_space_size_aligned * sizeof(FloatAcc0);

-        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size;
+        const index_t c_block_bytes_end =
+            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
    }
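Note: since the `SharedMemTrait` sizes are now element counts (see the hunk further down), the byte conversion happens here, once per region, and the total LDS requirement is the maximum over region end offsets because gemm0, gemm1, the softmax reduction, and the C shuffle reuse the same scratch at different phases. A standalone sketch of the accounting (element counts and types below are assumed):

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    // Assumed element counts; half-precision tiles = 2 bytes, float acc = 4 bytes.
    const int a_elems = 4096, b0_elems = 4096, b1_elems = 2048, red_elems = 256, c_elems = 4096;

    // Regions alive at the same time are laid out back to back;
    // regions from different phases overlap, so only the max end matters.
    const int gemm0_bytes_end   = a_elems * 2 + b0_elems * 2; // A then B0
    const int gemm1_bytes_end   = 0 + b1_elems * 2;           // B1 at offset 0
    const int softmax_bytes_end = 0 + red_elems * 4;          // reduction at offset 0
    const int c_bytes_end       = c_elems * 2;                // C shuffle

    const int lds_bytes =
        std::max({gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_bytes_end});
    std::printf("LDS bytes: %d\n", lds_bytes); // 16384 for these numbers
    return 0;
}
```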
@@ -416,7 +423,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+    CheckValidity(const AGridDesc& a_grid_desc,
                  const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
                  const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
@@ -426,19 +433,48 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                          (LPerBlock % (LPerWmma * LRepeat)) == 0,
                      "Invalid tuning param!");

-        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
+        const auto GetAProblemsizeMK = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return make_tuple(a_grid_desc.GetLength(I1),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
+                                      a_grid_desc.GetLength(I4),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                                      a_grid_desc.GetLength(I5));
+            }
+        };
+
+        const auto M = GetAProblemsizeMK()[I0];
        const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
-        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
+        const auto K = GetAProblemsizeMK()[I1];
        const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
+        const auto KPerBlock = K0PerBlock * K1Value;

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
        {
+            printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n",
+                   M,
+                   N,
+                   c_grid_desc_m_n.GetLength(I0),
+                   c_grid_desc_m_n.GetLength(I1));
            return false;
        }

        if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
        {
+            printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = "
+                   "%d, %d, %d, %d\n",
+                   M,
+                   L,
+                   K,
+                   N,
+                   MPerBlock,
+                   LPerBlock,
+                   KPerBlock,
+                   NPerBlock);
            return false;
        }
@@ -446,18 +482,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        const auto num_gemm0_k_loop = K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
        {
+            printf("GridwiseOp: outer loop unsupport\n");
            return false;
        }

        // check gemm1 gridwise gemm pipeline
-        if(!(LPerBlock % (L0PerBlock * L1Value) == 0))
+        if(!(LPerBlock % LTilePerBlock == 0))
        {
+            printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n",
+                   LPerBlock,
+                   LTilePerBlock);
            return false;
        }

-        const auto num_gemm1_k_inner_loop = LPerBlock / (L0PerBlock * L1Value);
+        const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
        {
+            printf("GridwiseOp: inner loop unsupport\n");
            return false;
        }
@@ -472,7 +513,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
-        const index_t num_loop = K / (K0PerBlock * K1Value);
+        const index_t num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
@@ -514,28 +555,38 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // LDS allocation for A and B: be careful of alignment
        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);

-        static constexpr auto a_block_space_size_aligned = AEnableLds ? math::integer_least_multiple(
-            MakeABlockDescriptor().GetElementSpaceSize() * sizeof(FloatA), max_lds_align) : 0;
-        static constexpr auto b0_block_space_size_aligned = B0EnableLds ? math::integer_least_multiple(
-            GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize() * sizeof(FloatB0), max_lds_align) : 0;
-        static constexpr auto b1_block_space_size_aligned = B1EnableLds ? math::integer_least_multiple(
-            GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize() * sizeof(FloatB1), max_lds_align) : 0;
+        static constexpr auto a_block_space_size_aligned =
+            AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
+                                                      max_lds_align)
+                       : 0;
+        static constexpr auto b0_block_space_size_aligned =
+            B0EnableLds
+                ? math::integer_least_multiple(
+                      GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize(),
+                      max_lds_align)
+                : 0;
+        static constexpr auto b1_block_space_size_aligned =
+            B1EnableLds
+                ? math::integer_least_multiple(
+                      GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize(),
+                      max_lds_align)
+                : 0;

        static constexpr auto a_block_space_offset = 0;
-        static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
+        static constexpr auto b0_block_space_offset = a_block_space_size_aligned;
        static constexpr auto b1_block_space_offset = 0;

        // LDS allocation for reduction
        // Feature to add, IntraThread Reduction
        static constexpr index_t reduction_space_size_aligned =
-            math::integer_least_multiple(BlockSize, max_lds_align) * sizeof(FloatAcc0);
+            math::integer_least_multiple(BlockSize, max_lds_align);

        static constexpr auto reduction_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_block_space_size =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
-                .GetElementSpaceSize() * sizeof(FloatCShuffle);
+                .GetElementSpaceSize();
    };
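Note: the motivation for keeping these sizes in elements rather than bytes is visible where the buffers are created. The offsets are added to an already-typed pointer (`static_cast<FloatA*>(p_shared) + a_block_space_offset`), so they must count elements of that type, and `integer_least_multiple` aligns the element count to `max_lds_align`. A minimal sketch of both points (all values assumed):

```cpp
#include <cassert>
#include <cstdint>

// Round n up to the next multiple of align, like CK's math::integer_least_multiple.
constexpr int integer_least_multiple(int n, int align) { return (n + align - 1) / align * align; }

int main()
{
    using FloatA = std::uint16_t; // stand-in for a 16-bit element type

    alignas(16) unsigned char lds[1024]; // stand-in for the p_shared arena
    void* p_shared = lds;

    constexpr int a_elems_aligned = integer_least_multiple(100, 8); // 104 elements
    constexpr int b0_offset = a_elems_aligned; // an element offset, not bytes

    // Typed-pointer arithmetic: +b0_offset advances b0_offset * sizeof(FloatA) bytes.
    FloatA* a_block  = static_cast<FloatA*>(p_shared);
    FloatA* b0_block = static_cast<FloatA*>(p_shared) + b0_offset;

    assert(reinterpret_cast<unsigned char*>(b0_block) - reinterpret_cast<unsigned char*>(a_block) ==
           b0_offset * static_cast<int>(sizeof(FloatA)));
    return 0;
}
```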
    template <bool HasMainKBlockLoop,
@@ -546,7 +597,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                               const FloatB1* __restrict__ p_b1_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
-                               const AGridDesc_AK0_M_AK1& a_grid_desc_k0_m_k1,
+                               const AGridDesc& a_grid_desc,
                               const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
                               const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -563,7 +614,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc.GetElementSpaceSize());
        const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -601,7 +652,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto a_block_desc = MakeABlockDescriptor();
        constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();

        auto a_block_trait = [&](){
            // A matrix blockwise copy
            if constexpr(AEnableLds)
@@ -610,17 +661,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
                    SharedMemTrait::a_block_space_size_aligned);

                auto a_blockwise_copy =
                    ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
                    /* typename SrcElementwiseOperation,              */ AElementwiseOperation,
                    /* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
                    /* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
-                    /* typename BlockSliceLengths,                    */ Sequence<AK0, MPerBlock, AK1>,
+                    /* typename BlockSliceLengths,                    */ Sequence<AK0PerBlock, MPerBlock, AK1>,
                    /* typename ThreadClusterLengths,                 */ ABlockTransferThreadClusterLengths_K0_M_K1,
                    /* typename ThreadClusterArrangeOrder,            */ ABlockTransferThreadClusterArrangeOrder,
                    /* typename SrcData,                              */ FloatA,
                    /* typename DstData,                              */ FloatA,
-                    /* typename SrcDesc,                              */ decltype(a_grid_desc_k0_m_k1),
+                    /* typename SrcDesc,                              */ decltype(a_grid_desc),
                    /* typename DstDesc,                              */ decltype(a_block_desc),
                    /* typename SrcDimAccessOrder,                    */ ABlockTransferSrcAccessOrder,
                    /* typename DstDimAccessOrder,                    */ Sequence<0, 1, 2>,
@@ -632,7 +684,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                    /* index_t DstScalarStrideInVector,               */ 1,
                    /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
                    /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
-                        a_grid_desc_k0_m_k1,
+                        a_grid_desc,
                        make_multi_index(0, m_block_data_idx_on_grid, 0),
                        a_element_op,
                        a_block_desc,
@@ -713,7 +765,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // Gemm0
-        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);

        auto blockwise_gemm0 = BlockwiseGemmWMMA<
@@ -725,7 +776,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
            MPerBlock,
            LPerBlock,
-            K0PerBlock * K1Value,
+            KPerBlock,
            MPerWmma,
            LPerWmma,
            MRepeat,
@@ -759,18 +810,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // LDS allocation for A and B: be careful of alignment
-        auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
-            b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());
+        auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
+            SharedMemTrait::b0_block_space_size_aligned);

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
-        constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto b0_block_slice_copy_step = make_multi_index(BK0, 0, 0);

        const auto a_block_reset_copy_step = [&](){
            if constexpr(AEnableLds){
-                return make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+                return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
            }
            else{
-                return make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0, 0, 0, 0);
+                return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0);
            }
        }();
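Note: the reset step moves the A copy's source window back to K = 0 after each L tile of gemm0: the window was advanced once per inner K iteration, and a single negative step of `-a_grid_desc.GetLength(I0)` along dim 0 undoes all of them (with extra trailing zeros in the 6-d direct-load layout). A scalar sketch of the same bookkeeping (sizes assumed):

```cpp
#include <cassert>

int main()
{
    // Assumed sizes: AK0 = total K0 extent of the grid descriptor,
    // BK0 = K0 advanced per inner K iteration.
    const int AK0 = 16, BK0 = 4;

    int k0_window = 0;                 // MoveSrcSliceWindow position along dim 0
    for(int i = 0; i < AK0 / BK0; ++i) // the gemm0 K loop
        k0_window += BK0;              // forward step per iteration

    k0_window += -AK0;                 // the reset step: rewind K for the next L tile
    assert(k0_window == 0);            // back at the start of K
    return 0;
}
```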
@@ -836,24 +889,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
        constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);

+        // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
        // A1 matrix in VGPR
        constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
            Number<AL0 * AL1 / laccvgprs>{},
            Number<mrepeat * mwave * mthreadpersubgroup>{},
-            Number<laccvgprs>{}); // Data duplicated dimension
+            Number<laccvgprs>{});

        constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
        constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
        constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];

-        // A1 has duplicated data
-        constexpr auto A1ThreadDuplicatedDim = I2 * A1ThreadSliceL1;
        constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
-            make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadDuplicatedDim),
-            make_tuple(A1ThreadSliceMPerBlock * A1ThreadDuplicatedDim, A1ThreadDuplicatedDim, I1));
+            make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1),
+            make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1));

        // A1 matrix blockwise copy
-        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatAcc0,
            FloatA,
            decltype(acc0_thread_desc_l0perblock_mperblock_l1),
@@ -862,13 +914,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            Sequence<A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1>,
            Sequence<0, 1, 2>,
            2,
-            laccvgprs,
-            // dst Rowlane
-            // 0x76543210 0xfedcba98
-            // src Rowlane
-            0x76543210, 0xfedcba98,
-            false>{};
+            laccvgprs>{tensor_operation::element_wise::PassThrough{}};
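Note: with the new blockwise GEMM the acc0 -> A1 staging no longer needs the inter-row variant with its duplicated-data dimension and row-lane masks (`0x76543210` / `0xfedcba98`); it reduces to a per-thread static copy that converts `FloatAcc0` to `FloatA` through `PassThrough`. A minimal sketch of what such a copy does per thread (CK's real transfer additionally walks an N-d slice with vectorized accesses):

```cpp
#include <cassert>

// Illustrative elementwise operation, mirroring element_wise::PassThrough:
// convert source to destination type, nothing else.
struct PassThrough
{
    template <typename Dst, typename Src>
    void operator()(Dst& dst, const Src& src) const { dst = static_cast<Dst>(src); }
};

int main()
{
    using FloatAcc0 = float;
    using FloatA    = float; // in the kernel this is typically the half-precision input type

    FloatAcc0 acc0_thread_buf[8] = {0.5f, 1.f, 1.5f, 2.f, 2.5f, 3.f, 3.5f, 4.f};
    FloatA    a1_thread_buf[8];

    PassThrough pass{};
    for(int i = 0; i < 8; ++i)
        pass(a1_thread_buf[i], acc0_thread_buf[i]); // register-to-register convert-and-copy

    assert(a1_thread_buf[3] == 2.f);
    return 0;
}
```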
        // B1 matrix blockwise copy
        auto b1_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
@@ -904,7 +951,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());

        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatB1*>(p_shared)+ SharedMemTrait::b1_block_space_offset,
-            b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());
+            SharedMemTrait::b1_block_space_size_aligned);

        auto blockwise_gemm1 =
            BlockwiseGemmWMMA<BlockSize,
@@ -915,7 +962,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
                MPerBlock,
                NPerBlock,
-                BL0 * BL1,
+                LTilePerBlock,
                MPerWmma,
                NPerWmma,
                MRepeat,
@@ -926,13 +973,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

        const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
-        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);
+        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock;

        // Initialize C
        StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
        c_thread_buf.Clear();

        /*******************************************************************************/
+        //
+        // Kernel Main Stage
+        //
        // Flash Attention
        // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022).
        index_t gemm1_l_block_outer_index = 0;
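Note: the outer loop that starts here implements the FlashAttention tiling cited above: each iteration computes one L tile of S = Q*K^T, updates a running row max and a rescaled running sum, and accumulates the rescaled partial P*V into C, so the full softmax row is never materialized. A scalar sketch of the running-max/rescale recurrence for one row (this is the published algorithm, not CK's exact buffer layout; all V values are set to 1 so the result is checkable):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

int main()
{
    // One attention row of scores, processed in L tiles of width 4 (values assumed).
    const std::vector<float> s = {0.1f, 0.7f, -0.3f, 0.2f, 0.9f, -1.0f, 0.4f, 0.0f};
    const int tile = 4;

    float m = -INFINITY; // running max
    float l = 0.f;       // running (rescaled) sum of exp
    float o = 0.f;       // running output, stand-in for the P*V accumulator

    for(std::size_t t = 0; t < s.size(); t += tile)
    {
        float m_tile = -INFINITY;
        for(int j = 0; j < tile; ++j) m_tile = std::fmax(m_tile, s[t + j]);

        const float m_new = std::fmax(m, m_tile);
        const float scale = std::exp(m - m_new); // rescale previously accumulated partials
        l = l * scale;
        o = o * scale;
        for(int j = 0; j < tile; ++j)
        {
            const float p = std::exp(s[t + j] - m_new);
            l += p;
            o += p * 1.0f; // 1.0f stands in for the V value of this column
        }
        m = m_new;
    }

    // Reference: direct softmax denominator over the whole row.
    float mref = -INFINITY, lref = 0.f;
    for(float x : s) mref = std::fmax(mref, x);
    for(float x : s) lref += std::exp(x - mref);

    assert(std::fabs(l - lref) < 1e-5f);       // streaming sum matches the direct one
    assert(std::fabs(o / l - 1.0f) < 1e-6f);   // all V = 1 -> normalized output is 1
    return 0;
}
```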
@@ -947,7 +997,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                    continue;
                }
                // gemm0 start, A-B swaped
-                GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                                  a_block_desc,
                                                                  a_blockwise_copy,
                                                                  a_grid_buf,
@@ -1019,10 +1069,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                    [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
            }

            block_sync_lds();
-            // gemm0 end
-            // gemm0 incorrect

            // Tiled softmax start
            // softmax
            SoftmaxBuf& max = blockwise_softmax.max_value_buf;
@@ -1130,7 +1177,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                });
            });

-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
                                                a_block_reset_copy_step); // rewind K
            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
                                                b0_block_reset_copy_step); // rewind K and step N
...
@@ -179,24 +179,32 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-    template <typename AGridDesc_M_K, typename Number>
+    template <typename AGridDesc_M_K,
+              typename WmmaK,
+              typename MRepeat,
+              typename MWaves,
+              typename MPerWmma,
+              typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
-        const AGridDesc_M_K& a_grid_desc_m_k, const Number& WmmaK, const Number& MRepeat,
-        const Number& MWaves, const Number& MPerWmma, const Number& AK1)
+        const AGridDesc_M_K& a_grid_desc_m_k,
+        const WmmaK&,
+        const MRepeat&,
+        const MWaves&,
+        const MPerWmma&,
+        const AK1&)
    {
-        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlcok;
+        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K = a_grid_desc_m_k.GetLength(I1);
-        const auto AKWmma = K / WmmaK;
-        constexpr auto AKRow = WmmaK / K1;
+        const auto AKWmma = K / WmmaK{};
+        constexpr auto AKRow = WmmaK{} / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(make_tuple(AKWmma, Number<AKRow>{}, AK1)),
-                       make_unmerge_transform(
-                           make_tuple(M0 * MRepeat, MWaves, MPerWmma))),
+            make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
+                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
    }

    //
...
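Note: the transform above unmerges K into (AKWmma, AKRow, AK1) and M into (M0*MRepeat, MWaves, MPerWmma), then interleaves them as dims (0, 3, 5) and (1, 2, 4) of the 6-d descriptor. Giving the five sizes distinct template types is what lets each stay its own `Number<>` compile-time constant instead of all being forced into one deduced `Number` type (which is also what made the old `MPerBlcok` typo and the `WmmaK / K1` mismatch possible). A plain index sketch of the K-side unmerge (sizes assumed, matching the earlier K-recovery example):

```cpp
#include <cassert>

int main()
{
    // Assumed: WmmaK = 16, AK1 = 8 -> AKRow = WmmaK / AK1 = 2.
    const int AKRow = 2, AK1 = 8;

    // Unmerge a flat k into (kwmma, krow, k1); these land in dims 0, 3, 5
    // of the 6-d descriptor, while dims 1, 2, 4 carry the M unmerge.
    const int k     = 37;
    const int k1    = k % AK1;           // 5
    const int krow  = (k / AK1) % AKRow; // 0
    const int kwmma = k / (AK1 * AKRow); // 2

    assert(kwmma * AKRow * AK1 + krow * AK1 + k1 == k); // the unmerge is lossless
    return 0;
}
```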