Commit 465f8e6a authored by Adam Osewski's avatar Adam Osewski
Browse files

Update B LDS layout and setup tile distribution pattern at class level.

parent 69d6660c
...@@ -234,24 +234,33 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -234,24 +234,33 @@ struct UniversalGemmPipelineAgBgCrPolicy
return a_lds_block_desc; return a_lds_block_desc;
} }
/**
* @brief Create LDS block descriptor for B tensor.
*
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{ {
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackB<Problem>();
#if 1
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType); constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer = constexpr auto NLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * NLdsLayer>{}, make_tuple(
number<NPerBlock / NLdsLayer>{}, BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}), make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
number<KPack>{}, number<KPack>{},
number<1>{}); number<1>{});
...@@ -259,31 +268,160 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -259,31 +268,160 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0, b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{}, make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
number<KPerBlock / KPack * NLdsLayer>{})), BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})), make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{})); make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted, b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform( make_tuple(make_unmerge_transform(make_tuple(BK0, number<NLdsLayer>{})),
make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}), make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<KPack>{})), make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor( constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1, b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(make_merge_transform_v3_division_mod( make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})), make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod( make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
// make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc; return b_lds_block_desc;
} }
#else
else // B is Row Major
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
// constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N0 = TileEncodingPattern::X0;
constexpr auto N1 = NPerBlock / N0;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
// constexpr auto KThreadWrite =
// BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
constexpr auto kfold =
(BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128)
? 1
: ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0
? N0
: 128 / (BK1 * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
BK1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
// constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_lds_block_desc_unmerged,
// make_tuple(make_merge_transform_v3_division_mod(
// make_tuple(number<KThreadReadPerm>{},
// number<KThreadWrite / kfold / KThreadReadPerm>{},
// number<kfold>{},
// number<K0PerThreadWrite>{})),
// make_merge_transform_v3_division_mod(
// make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
// make_pass_through_transform(BK1)),
// make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
// make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
BK1)),
make_merge_transform_v3_division_mod(
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// return b_lds_block_desc_bk0_n_bk1;
return b_lds_block_desc_kn;
// constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
// make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
// number<KPack>{},
// number<1>{});
// constexpr auto b_lds_block_desc = transform_tensor_descriptor(
// b_lds_block_desc_bk0_n_bk1,
// make_tuple(make_pass_through_transform(number<NPerBlock>{}),
// make_merge_transform_v3_division_mod(make_tuple(BK0,
// number<KPack>{}))),
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return b_lds_block_desc;
}
#endif
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
...@@ -326,24 +464,22 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -326,24 +464,22 @@ struct UniversalGemmPipelineAgBgCrPolicy
{ {
// We should take layout into account! // We should take layout into account!
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>(); constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
MPerBlock, MPerBlock,
KPerBlock, KPerBlock,
VecLoadSize, VecLoadSize,
AccessPattern>; ATileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution(); return TileEncodingPattern::Make2DStaticTileDistribution();
} }
// Tile: KPerBlock X MPerBlock // Tile: KPerBlock X MPerBlock
else else
{ {
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>(); constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock, KPerBlock,
MPerBlock, MPerBlock,
VecLoadSize, VecLoadSize,
AccessPattern>; ATileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution(); return TileEncodingPattern::Make2DStaticTileDistribution();
} }
} }
...@@ -361,37 +497,44 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -361,37 +497,44 @@ struct UniversalGemmPipelineAgBgCrPolicy
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>(); constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock, KPerBlock,
NPerBlock, NPerBlock,
VecLoadSize, VecLoadSize,
AccessPattern>; BTileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution(); return TileEncodingPattern::Make2DStaticTileDistribution();
} }
// Tile: NPerBlock X KPerBlock // Tile: NPerBlock X KPerBlock
else else
{ {
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>(); constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr auto AccessPattern = tile_distribution_pattern::thread_raked;
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
VecLoadSize, VecLoadSize,
AccessPattern>; BTileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution(); return TileEncodingPattern::Make2DStaticTileDistribution();
} }
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>); static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
#if 1
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
ATileAccessPattern>;
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
#else
constexpr index_t M1 = GetVectorSizeA<Problem>(); constexpr index_t M1 = GetVectorSizeA<Problem>();
constexpr index_t M0 = MPerBlock / M1; constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
...@@ -428,53 +571,26 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -428,53 +571,26 @@ struct UniversalGemmPipelineAgBgCrPolicy
sequence<1, 2>, sequence<1, 2>,
sequence<1, 3>>{}); sequence<1, 3>>{});
} }
#endif
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{ {
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>); static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t N1 = GetVectorSizeB<Problem>(); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t N0 = NPerBlock / N1; KPerBlock,
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; NPerBlock,
static_assert(total_pixels % N1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / N1; BTileAccessPattern>;
constexpr index_t kKPack = GetSmemPackB<Problem>(); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * N0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * N0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment