Commit 6fd51c43 authored by coderfeli

rm useless code

parent 8d2f2f8c
@@ -11,6 +11,13 @@ namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
// Differences from v1:
// 1. use an MWarp x NWarp = 2x2 warp layout
// 2. use a 32x32x16 block gemm
// 3. expose the A-LDS and B-LDS distributions
// 4. implement a sub-tile construct for the output C shuffle
// 5. reformat some code
// TODO: merge these variants using the universal gemm
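// Worked sizing example for (1)-(2) (hypothetical block-tile sizes, not taken
// from this commit): with a 128x128x32 block tile, MWarp x NWarp = 2x2 and a
// 32x32x16 warp gemm,
//   MIterPerWarp = 128 / (2 * 32) = 2
//   NIterPerWarp = 128 / (2 * 32) = 2
//   KIterPerWarp =  32 / 16      = 2
// so each warp issues 2 * 2 * 2 = 8 warp-gemm ops per block-gemm invocation.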
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV2
{
@@ -44,22 +51,8 @@ struct BlockGemmARegBRegCRegV2
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// M->N Warp
constexpr auto a_block_dstr_encode = MakeABlockDistribution();
constexpr auto b_block_dstr_encode = MakeBBlockDistribution();
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
@@ -69,12 +62,6 @@ struct BlockGemmARegBRegCRegV2
sequence<1, 2>,
sequence<0, 0>>{};
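// Reading the tile_distribution_encoding template arguments, in order:
// R (replication) lengths, H (per-dimension hierarchical) lengths,
// P (thread-partition) -> RH major/minor mappings, and
// Y (per-thread register) -> RH major/minor mappings.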
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
@@ -169,36 +156,29 @@ struct BlockGemmARegBRegCRegV2
return c_block_tensor;
}
// for C shuffle; currently disabled
// CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
// {
// constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
// tuple<sequence<1, 2>>,
// tuple<sequence<0, 1>>,
// sequence<2>,
// sequence<0>>{};
// constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
// return c_block_tensor;
// }
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{
// M->N Warp
// using AWarpDstrEncoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, //<32>, <2, 8>
// tuple<sequence<2, 1>>,
// tuple<sequence<0, 0>>,
// sequence<2>,
// sequence<1>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // <4, 2>, <2>
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
@@ -224,16 +204,6 @@ struct BlockGemmARegBRegCRegV2
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
return b_block_dstr;
// return make_static_distributed_tensor<BDataType>(b_block_dstr);
}
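// All three Make*Distribution helpers follow the same two-step pattern
// (sketch, mirroring the embed calls visible above for the C tile): embed the
// block-level outer encoding into the warp-gemm lane encoding, then
// materialize it, e.g.
//   constexpr auto encode = detail::make_embed_tile_distribution_encoding(
//       b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
//   return make_static_tile_distribution(encode);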
// Prefetch lds
template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static void PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{
auto tileDist = BlockTensor::get_tile_distribution();
// load_tile(block_tensor, make_tile_window(block_window, tileDist));
load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
}
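// Hypothetical usage of PrefetchLds (tile/window names assumed, not from this
// commit):
//   ALdsTile a_reg;
//   PrefetchLds(a_lds_window, a_reg); // linear LDS window -> VGPR tile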
// C = A * B
......
@@ -25,7 +25,6 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
@@ -214,20 +213,6 @@ struct GemmKernel
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
// using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
// static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
// {
// CSubTileDistr c_sub_tile;
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
// merge_sequences(sequence<1>{}, c_sub_y_lengths));
// EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
// move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
// });
}
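// note: the commented-out path above would split the C block tile into
// GemmPipeline::NumCSubTile() slices along M, run the epilogue per slice
// through LDS (smem_ptr), and advance CBlockWindow_pad between slices; it is
// disabled together with the C-shuffle MakeCBlockSubTile in the block gemm.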
};
......
@@ -39,18 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kHasHotLoop = Problem::kHasHotLoop;
static constexpr auto kTailNum = Problem::kTailNum;
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// {
// return integer_least_multiple(
// sizeof(ADataType) *
// Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2 +
// integer_least_multiple(
// sizeof(BDataType) *
// Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2;
// }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
@@ -58,7 +46,7 @@ struct GemmPipelineAGmemBGmemCRegV1
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE static void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window)
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, kKPerBlock});
@@ -66,84 +54,80 @@ struct GemmPipelineAGmemBGmemCRegV1
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE static void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func)
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE static void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window)
{
load_tile(dst_block_tile, lds_tile_window);
}
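// Minimal sketch of one software-pipelined stage built from the three helpers
// above (illustrative only; the a_*/b_* names are placeholders for the windows
// created in operator()):
//   block_sync_lds();
//   LocalPrefetch(a_reg, a_lds_ld_window);                   // LDS  -> VGPR
//   LocalPrefill(a_lds_window, a_gmem_tile, a_element_func); // VGPR -> LDS
//   GlobalPrefetch(a_gmem_tile, a_dram_window);              // DRAM -> VGPR, advance K
//   block_gemm(c_block_tile, a_reg_prev, b_reg_prev);        // MFMA on previous tiles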
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
// schedule
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{}); // 32
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{}); // 32
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{}); // 8
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{}); // 2
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{}); // 2
constexpr index_t A_LDS_Read_Width = KPerXDL; // 8
constexpr index_t B_LDS_Read_Width = KPerXDL; // 8
constexpr index_t num_buffer_load_inst_a =
    kMPerBlock * kKPerBlock / (BlockSize * VectorSizeA); // 4
constexpr index_t num_buffer_load_inst_b =
    kNPerBlock * kKPerBlock / (BlockSize * VectorSizeB); // 4
constexpr index_t num_ds_write_inst_a =
    kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t num_ds_write_inst_b =
    kNPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 4
constexpr index_t A_LDS_Read_Inst_Num =
    WaveNumN * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t B_LDS_Read_Inst_Num =
    WaveNumM * kMPerBlock * kKPerBlock / (BlockSize * KPerXDL); // 8
constexpr index_t num_mfma_inst = kMPerBlock * kNPerBlock * kKPerBlock /
                                  (BlockSize / WaveSize) /
                                  (MPerXDL * NPerXDL * KPerXDL); // 64
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
                                        ? A_LDS_Read_Inst_Num
                                        : A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
                                        ? B_LDS_Read_Inst_Num
                                        : B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst     = num_ds_read_inst_a + num_ds_read_inst_b;         // 16
constexpr auto num_ds_write_inst    = num_ds_write_inst_a + num_ds_write_inst_b;       // 8
constexpr auto num_buffer_load_inst = num_buffer_load_inst_a + num_buffer_load_inst_b; // 8
constexpr auto num_issue = num_buffer_load_inst; // 8
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
    0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(
    0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
__builtin_amdgcn_sched_group_barrier(
    0x008, num_mfma_inst / num_issue - 3, 0); // MFMA : 5
});
__builtin_amdgcn_sched_barrier(0);
}
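// Sanity check on the schedule above, using the example counts from the
// comments: each of the num_issue = 8 slots interleaves 1 + 1 + 1 MFMA plus
// (64 / 8 - 3) = 5 trailing MFMA with 2 ds_reads, 1 ds_write and 1 buffer_load,
// so the 8 slots cover 64 MFMA, 16 ds_reads, 8 ds_writes and 8 buffer_loads,
// exactly the per-block instruction counts computed at the top.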
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
@@ -226,15 +210,15 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A LDS tile window for store
auto a_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
auto a_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ABlockTileDistr);
// B LDS tile window for store
auto b_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
auto b_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BBlockTileDistr);
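// note: the *0/*1 window pairs implement LDS double buffering - while the
// block gemm consumes one buffer, the next global tile is written into the
// other, and the roles swap every hot-loop iteration (ping/pong below).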
// Block GEMM
@@ -253,8 +237,6 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds();
constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
@@ -267,8 +249,10 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_ld_window0 = make_tile_window_linear(b_lds_ld_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 = make_tile_window_linear(b_lds_ld_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}, BLdsTileDistr);
// local prefetch 0
// A/B register tiles for lds prefetch & mfma
LocalPrefetch(a_block_tile0, a_lds_ld_window0);
LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
@@ -278,43 +262,43 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
ALdsTile a_block_tile1;
BLdsTile b_block_tile1;
if (kHasHotLoop) {
index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
do
{
// ping
{
block_sync_lds();
// prefetch lds -> vgpr
LocalPrefetch(a_block_tile1, a_lds_ld_window1);
LocalPrefetch(b_block_tile1, b_lds_ld_window1);
// prefill vgpr -> lds
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
// prefetch global -> vgpr
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// pong
{
block_sync_lds();
// prefetch lds -> vgpr
LocalPrefetch(a_block_tile0, a_lds_ld_window0);
LocalPrefetch(b_block_tile0, b_lds_ld_window0);
// prefill vgpr -> lds
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
// prefetch global -> vgpr
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -328,9 +312,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds();
LocalPrefetch(a_block_tile1, a_lds_ld_window1);
LocalPrefetch(b_block_tile1, b_lds_ld_window1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
@@ -338,14 +321,14 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
{
block_sync_lds();
LocalPrefetch(a_block_tile0, a_lds_ld_window0);
LocalPrefetch(b_block_tile0, b_lds_ld_window0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
}
else
@@ -353,9 +336,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// tail 2
{
block_sync_lds();
LocalPrefetch(a_block_tile1, a_lds_ld_window1);
LocalPrefetch(b_block_tile1, b_lds_ld_window1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
@@ -365,21 +347,11 @@ struct GemmPipelineAGmemBGmemCRegV1
});
__builtin_amdgcn_sched_barrier(0);
}
// 2
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
// constexpr auto c_lds_block_desc = Policy::template MakeCLdsBlockDescriptor<Problem>();
// auto c_lds_block = make_tensor_view<address_space_enum::lds>(reinterpret_cast<CDataType*>(p_smem), c_lds_block_desc);
// auto c_lds_window0 = make_tile_window(c_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
// store_tile(c_lds_window0, c_block_tile);
// block_sync_lds();
return c_block_tile;
}
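// Pipeline shape, as implemented above:
//   prologue : GlobalPrefetch A/B, prefill LDS0, local prefetch 0,
//              prefill LDS1, GlobalPrefetch A/B again
//   hot loop : ping (consume tile0 / refill LDS0) alternating with
//              pong (consume tile1 / refill LDS1)
//   drain    : kHasHotLoop ? stages 3/2/1 : 2-stage tail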
@@ -401,25 +373,4 @@ struct GemmPipelineAGmemBGmemCRegV1
}
};
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
} // namespace ck_tile
@@ -16,37 +16,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
using BlockGemm = BlockGemmARegBRegCRegV2<Problem, BlockGemmPolicy>;
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
@@ -88,8 +58,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
// make_tuple(make_pass_through_transform(kNPerBlock),
// make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / 8>{}, number<8>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
@@ -135,76 +103,76 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
#elif 1
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(ADataType);
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
a_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc_n_k;
}
#endif
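// The "fake XOR" layout above stores an (M, K) tile as (M/2, 2, K) and
// XOR-mixes the M/2 coordinate into the K coordinate in kK1-element
// (16-byte) chunks, so rows that would otherwise map to the same LDS bank
// are spread across banks. Generic sketch of the trick (illustrative, not
// the exact make_xor_transform semantics):
//   // num_chunks = kKPerBlock / kK1
//   swizzled_chunk = k_chunk ^ (row % num_chunks);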
template <typename Problem>
......