Commit 7e003d31 authored by aska-0096

Porting new blockwise gemm to flash attention

parent 84b4ada5
@@ -100,12 +100,12 @@ using DeviceGemmInstance =
         32,  // KPerBlock
         8,   // K1
         // Gemm 1
-        64,  // NPerBlock
-        32,  // LPerBlock
-        8,   // L1
-        16,  // MPerWMMA
-        16,  // LPerWMMA
-        16,  // NPerWMMA
+        64,  // NPerBlock
+        32,  // LTilePerBlock
+        8,   // L1
+        16,  // MPerWMMA
+        16,  // LPerWMMA
+        16,  // NPerWMMA
         // Per repeat = wave_m = wave_num, wave_n = 1
         1,   // MRepeat
         8,   // LRepeat
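The hunk above renames the Gemm 1 tile parameter: the per-block tile along L, the dimension shared by both GEMMs, is now `LTilePerBlock`. For orientation (an assumed reading of the fused kernel based on the parameter names, not taken from this diff): Gemm 0 contracts over K to produce an M×L tile, softmax runs over L, and Gemm 1 contracts that L tile against B1. A minimal sanity sketch of the divisibility the innermost vector widths must satisfy:

```cpp
// Sketch only; values copied from the instance above, the static_asserts
// are assumed constraints, not the library's validation code.
constexpr int KPerBlock     = 32; // Gemm 0 contraction tile
constexpr int K1            = 8;  // innermost K vector width
constexpr int LTilePerBlock = 32; // per-block tile along the shared L dim
constexpr int L1            = 8;  // innermost L vector width

static_assert(KPerBlock % K1 == 0, "K1 must divide KPerBlock");
static_assert(LTilePerBlock % L1 == 0, "L1 must divide LTilePerBlock");
```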
@@ -124,7 +124,7 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<4, 8, 8>, // B1BlockTransfer LN -> L0 N L1
+        S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
......
@@ -122,20 +122,20 @@ int run(int argc, char* argv[])
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         break;
-    case 5: // Rand: b1 ; unit: a b0 fail
+    case 5: // Rand: b1 b0; unit: a
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
-    case 6: // Rand: b0 ; unit: a b1 pass
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+    case 6: // Rand: a b0 ; unit: b1 pass
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         break;
-    case 7: // Rand: a ; unit: b0 b1 pass
+    case 7: // Rand: a b1 ; unit: b0 pass
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
......
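The reshuffled cases isolate which operands carry random data, which helps bisect numerical mismatches: an all-ones "unit" operand reduces its GEMM to a row sum, so a failing case points at the randomized side. A simplified sketch of the two generators' assumed semantics (modeled on CK's host tensor generators, not quoted from them):

```cpp
#include <cstdlib>

// GeneratorTensor_1: constant fill (the "unit" input above).
template <typename T>
struct GeneratorTensor_1
{
    T value = 1;
    template <typename... Is>
    T operator()(Is...) const { return value; }
};

// GeneratorTensor_2{min, max}: random integers in [min, max).
template <typename T>
struct GeneratorTensor_2
{
    int min_value = 0;
    int max_value = 1;
    template <typename... Is>
    T operator()(Is...) const
    {
        return static_cast<T>(min_value + std::rand() % (max_value - min_value));
    }
};
```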
@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

 #define CK_MNK_LOOP
@@ -340,6 +341,7 @@ struct BlockwiseGemmWMMA
                                    b_thread_desc_,
                                    make_tuple(I0, n0, I0, I0, I0),
                                    b_thread_buf);
+
                vector_type<FloatA, WmmaK> a_thread_vec;
                vector_type<FloatB, WmmaK> b_thread_vec;
@@ -413,7 +415,7 @@ struct BlockwiseGemmWMMA
                                                   A_K1,
                                                   0x76543210,
                                                   0xfedcba98,
-                                                  true>;
+                                                  TransposeC ? false : true>;
    };

    template <bool EnableLds>
@@ -448,7 +450,7 @@ struct BlockwiseGemmWMMA
                                                   B_K1,
                                                   0x76543210,
                                                   0xfedcba98,
-                                                  false>;
+                                                  TransposeC ? true : false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
......
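The two hunks above make the thread-copy transpose flags a function of `TransposeC` instead of hard-coded constants. The assumed intent: when the blockwise GEMM writes C transposed, the A and B operands swap roles, so each side's per-thread copy flips its transpose flag. A minimal sketch of just that selection logic (the real `AThreadCopySelector`/`BThreadCopySelector` also choose between LDS and direct-load copy types via `EnableLds`):

```cpp
// Hypothetical helper for illustration; not a CK type.
template <bool TransposeC>
struct ThreadCopyTransposeFlags
{
    static constexpr bool a_transpose = TransposeC ? false : true; // !TransposeC
    static constexpr bool b_transpose = TransposeC ? true : false; //  TransposeC
};

static_assert(ThreadCopyTransposeFlags<false>::a_transpose &&
              !ThreadCopyTransposeFlags<false>::b_transpose);
static_assert(!ThreadCopyTransposeFlags<true>::a_transpose &&
              ThreadCopyTransposeFlags<true>::b_transpose);
```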
@@ -56,11 +56,11 @@ template <index_t NumDimG,
           ck::index_t KPerBlock,
           ck::index_t K1,
           ck::index_t NPerBlock,
-          ck::index_t LPerBlock,
+          ck::index_t LTilePerBlock,
           ck::index_t L1,
-          ck::index_t MPerWMMA,
-          ck::index_t LPerWMMA,
-          ck::index_t NPerWMMA,
+          ck::index_t MPerWmma,
+          ck::index_t LPerWmma,
+          ck::index_t NPerWmma,
           ck::index_t MRepeat,
           ck::index_t LRepeat,
           ck::index_t NRepeat,
@@ -134,15 +134,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
+   static constexpr auto I4 = Number<4>{};
+   static constexpr auto I5 = Number<5>{};

+   static constexpr auto WmmaK  = 16;
    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
-   static constexpr auto WmmaK  = 16;

-   static constexpr auto AEnableLds = LWaves == 1 ? false : true;
-   // static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
-   // static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
+   static constexpr auto AEnableLds  = LWaves == 1 ? false : true;
+   static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
+   static constexpr auto B1EnableLds = MWaves == 1 ? false : true;

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
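The wave counts and the now-enabled per-operand `EnableLds` switches follow a simple rule: an operand tile only needs LDS staging when more than one wave must share it along the other GEMM dimension. A worked sketch of the arithmetic with assumed values (the real ones are template parameters of the device op):

```cpp
// Assumed tile sizes for illustration only.
constexpr int MPerBlock = 128, MRepeat = 1, MPerWmma = 16;
constexpr int LPerBlock = 128, LRepeat = 8, LPerWmma = 16;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 128 / 16  = 8
constexpr int LWaves = LPerBlock / (LRepeat * LPerWmma); // 128 / 128 = 1

// One wave along L -> each A sub-tile is private to a wave, so skip LDS for A;
// eight waves along M -> B0 is shared across waves, so stage it through LDS.
constexpr bool AEnableLds  = LWaves == 1 ? false : true; // false here
constexpr bool B0EnableLds = MWaves == 1 ? false : true; // true here
static_assert(!AEnableLds && B0EnableLds);
```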
@@ -165,14 +168,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        else
        {
            return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
-               Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
-               WmmaK, Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{}, Number<K1>{})
+               Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
+               Number<WmmaK>{},
+               Number<MRepeat>{},
+               Number<MWaves>{},
+               Number<MPerWmma>{},
+               Number<K1>{});
        }
    }

-   static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
-                                              const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
+   static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
+                                    const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
    {
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
@@ -188,7 +194,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    }

    using AGridDesc            = decltype(MakeAGridDescriptor({}, {}));
-   using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
+   using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
    using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
@@ -277,11 +283,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        KPerBlock,
        K1,
        NPerBlock,
-       LPerBlock,
+       LTilePerBlock,
        L1,
-       MPerWMMA,
-       LPerWMMA,
-       NPerWMMA,
+       MPerWmma,
+       LPerWmma,
+       NPerWmma,
        MRepeat,
        LRepeat,
        NRepeat,
@@ -357,10 +363,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
-             a_grid_desc_ak0_m_ak1_{
-                 DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
-             b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
-                 b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
+             a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+             b0_grid_desc_bk0_l_bk1_{
+                 DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
                  b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
@@ -405,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            ignore = acc1_biases_gs_ms_ns_lengths;
            ignore = acc1_biases_gs_ms_ns_strides;

-           if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
+           if(GridwiseOp::CheckValidity(a_grid_desc,
                                         b0_grid_desc_bk0_l_bk1_,
                                         b1_grid_desc_bl0_n_bl1_,
                                         c_grid_desc_m_n_,
@@ -424,7 +429,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        CDataType* p_c_grid_;

        // Tensor Descriptors
-       AGridDesc a_grid_desc_ak0_m_ak1_;
+       AGridDesc a_grid_desc;
        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        CGridDesc_M_N c_grid_desc_m_n_;
@@ -473,8 +478,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

-           const auto K =
-               arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+           const auto K = [&]() {
+               if constexpr(AEnableLds)
+               {
+                   return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
+               }
+               else
+               {
+                   return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
+                          arg.a_grid_desc.GetLength(I5);
+               }
+           }();

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
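Because the A descriptor's shape now depends on `AEnableLds`, the total K extent is reconstructed per path: with LDS the descriptor is (AK0, M, AK1), so K = AK0 · AK1; without LDS it is the 6-d (AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1) layout, so K = AKWmma · AKRow · AK1 (the lengths at I0, I3, I5). A quick numeric check with assumed sizes:

```cpp
// Assumed sizes for illustration only.
constexpr int K = 64, AK1 = 8, WmmaK = 16;

constexpr int AK0    = K / AK1;     // 8 -> LDS path:    8 * 8     = 64
constexpr int AKWmma = K / WmmaK;   // 4 -> direct path: 4 * 2 * 8 = 64
constexpr int AKRow  = WmmaK / AK1; // 2

static_assert(AK0 * AK1 == K);
static_assert(AKWmma * AKRow * AK1 == K);
```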
@@ -506,7 +520,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                    arg.p_b0_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
-                   arg.a_grid_desc_ak0_m_ak1_,
+                   arg.a_grid_desc,
                    arg.b0_grid_desc_bk0_l_bk1_,
                    arg.b1_grid_desc_bl0_n_bl1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -551,20 +565,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
+               printf("DeviceOp: Acc0 Type err");
                return false;
            }

            if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
            {
+               printf("DeviceOp: Acc1 Type err");
                return false;
            }
        }
        else
        {
+           printf("DeviceOp: Arch err");
            return false;
        }

-       if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+       if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b0_grid_desc_bk0_l_bk1_,
                                      arg.b1_grid_desc_bl0_n_bl1_,
                                      arg.c_grid_desc_m_n_,
@@ -574,14 +591,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }

        // Check if C permute dimension matches GEMM + GEMM shape
-       const index_t c_g  = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m  = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_n  = arg.c_grid_desc_m_n_.GetLength(I1);
-       const index_t a_m  = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-       const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);
+       const index_t c_g  = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded

-       if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
+       if(!(c_g == arg.batch_count_))
        {
+           printf("DeviceOp: BatchCount err");
            return false;
        }
@@ -604,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
+           printf("DeviceOp: Data Transfer Vector scalar err");
            return false;
        }
@@ -619,6 +634,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
+           printf("DeviceOp: Data Vectorize transfer err");
            return false;
        }
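The two new diagnostics above label the kernel's vectorization preconditions: each operand's innermost extent must be a multiple of its scalar-per-vector width, and at least one operand must be contiguous (stride 1) in its lowest dimension. A minimal sketch with assumed values (the real extents come from the grid descriptors, the widths from the template parameters):

```cpp
// Assumed values for illustration only.
constexpr int a_extent_lowest                  = 64; // innermost extent of A
constexpr int ABlockTransferSrcScalarPerVector = 8;  // elements per vector load
constexpr int a_stride_lowest = 1, b0_stride_lowest = 4,
              b1_stride_lowest = 4, c_stride_lowest = 1;

static_assert(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0,
              "vector width must divide the innermost extent");
static_assert(a_stride_lowest == 1 || b0_stride_lowest == 1 ||
              b1_stride_lowest == 1 || c_stride_lowest == 1,
              "at least one operand must be contiguous in its lowest dim");
```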
@@ -765,7 +781,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            << K1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-           << LPerBlock << ", "
+           << LTilePerBlock << ", "
            << L1
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
@@ -179,24 +179,32 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-   template <typename AGridDesc_M_K, typename Number>
+   template <typename AGridDesc_M_K,
+             typename WmmaK,
+             typename MRepeat,
+             typename MWaves,
+             typename MPerWmma,
+             typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
-       const AGridDesc_M_K& a_grid_desc_m_k, const Number& WmmaK, const Number& MRepeat,
-       const Number& MWaves, const Number& MPerWmma, const Number& AK1)
+       const AGridDesc_M_K& a_grid_desc_m_k,
+       const WmmaK&,
+       const MRepeat&,
+       const MWaves&,
+       const MPerWmma&,
+       const AK1&)
    {
-       const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlcok;
+       const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K  = a_grid_desc_m_k.GetLength(I1);

-       const auto AKWmma    = K / WmmaK;
-       constexpr auto AKRow = WmmaK / K1;
+       const auto AKWmma    = K / WmmaK{};
+       constexpr auto AKRow = WmmaK{} / AK1{};

        return transform_tensor_descriptor(
-           a_grid_desc_m_k,
-           make_tuple(make_unmerge_transform(make_tuple(AKWmma, Number<AKRow>{}, AK1)),
-                      make_unmerge_transform(
-                          make_tuple(M0 * MRepeat, MWaves, MPerWmma))),
-           make_tuple(Sequence<1>{}, Sequence<0>{}),
-           make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+           a_grid_desc_m_k,
+           make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
+                      make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
+           make_tuple(Sequence<1>{}, Sequence<0>{}),
+           make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
    }

    //
......
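The rewritten transform gives each tile size its own template parameter (the old signature reused a single `Number` type for all five arguments, which cannot hold five distinct compile-time values), and fixes the `MPerBlcok` typo and the `K1`/`AK1` mix-up. Functionally it unmerges the 2-d (M, K) descriptor into the 6-d (AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1) layout, sending K to dims {0, 3, 5} and M to dims {1, 2, 4}. A numeric check with assumed extents:

```cpp
// Assumed extents for illustration only.
constexpr int M = 128, K = 64;
constexpr int MPerBlock = 128, MRepeat = 1, MWaves = 8, MPerWmma = 16;
constexpr int WmmaK = 16, AK1 = 8;

constexpr int M0     = M / MPerBlock; // 1 block repeat along M
constexpr int AKWmma = K / WmmaK;     // 4
constexpr int AKRow  = WmmaK / AK1;   // 2

// The unmerged factors must multiply back to the original extents.
static_assert(AKWmma * AKRow * AK1 == K);               // 4 * 2 * 8  = 64
static_assert((M0 * MRepeat) * MWaves * MPerWmma == M); // 1 * 8 * 16 = 128
```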