Commit 532eb870 authored by coderfeli
Browse files

fix warning and use default epilog and one out

parent 613e45b9
......@@ -48,14 +48,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
CDataType,
M_Warp * N_Warp * K_Warp * Warp_Size,
TilePartitioner::kM,
TilePartitioner::kN,
kPadM,
kPadN>>;
// using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
// CDataType,
// M_Warp * N_Warp * K_Warp * Warp_Size,
// 64,
// TilePartitioner::kN,
// kPadM,
// kPadN>>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
......
......@@ -32,8 +32,8 @@ struct CShuffleEpilogueV2Problem
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
{
static constexpr index_t kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
constexpr index_t kMPerBlock = Problem::kMPerBlock;
constexpr index_t kNPerBlock = Problem::kNPerBlock;
return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
......@@ -45,10 +45,10 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{
static constexpr index_t kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
constexpr index_t kMPerBlock = Problem::kMPerBlock;
constexpr index_t kNPerBlock = Problem::kNPerBlock;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
using ODataType = remove_cvref_t<typename Problem::ODataType>;
// using OLayout = remove_cvref_t<typename Problem::OLayout>;
......@@ -83,8 +83,9 @@ struct CShuffleEpilogueV2
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr bool kMPerBlock = 64;
static constexpr bool kNPerBlock = Problem::kNPerBlock;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
// static constexpr bool kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536;}//kMPerBlock * kNPerBlock * sizeof(ODataType); }
......@@ -104,6 +105,7 @@ struct CShuffleEpilogueV2
auto o_dram_distri = MakeODramTileDistribution<Problem>();
auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri));
store_tile(o_dram_window_tmp, o_dram_tile);
block_sync_lds();
}
};
} // namespace ck_tile
......@@ -212,61 +212,21 @@ struct GemmKernel
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
CSubTileDistr c_sub_tile;
// printf("!!!!!!!!!!!!!!!!!!!!");
// c_sub_tile.get_tile_distribution().print();
// if (threadIdx.x==0) {
// printf("!!!!!!!!!!!!!!!!!!!!~~~ %d %d\n", c_block_tile.get_tile_distribution().get_num_of_dimension_y(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// // c_block_tile.get_tile_distribution().print();
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
// using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
// static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
// {
// CSubTileDistr c_sub_tile;
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_y_index_zeros.print();
// c_sub_y_lengths.print();
// }
// auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// if (threadIdx.x == 0) {
// c_sub_y_index_zeros.print();
// printf("\n");
// c_sub_y_lengths.print();
// printf("\n");
// printf("%d %d\n", GemmPipeline::NumCSubTile(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// }
// auto tbuf = c_block_tile.get_thread_buffer();
// for (index_t i = 0; i < tbuf.size(); i++) {
// if (threadIdx.x<16) {
// tbuf.set_as(i, float(threadIdx.x * 100 + i));
// } else {
// tbuf.set_as(i, float(threadIdx.x));
// }
// }
// c_block_tile.get_thread_buffer() = tbuf;
// if (threadIdx.x ==0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
{
constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
merge_sequences(sequence<1>{}, c_sub_y_lengths));
// c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
// merge_sequences(sequence<1>{}, c_sub_y_lengths));
EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
});
// EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
// move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
// });
}
};
......
......@@ -249,36 +249,6 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
......@@ -321,29 +291,30 @@ struct GemmPipelineAGmemBGmemCRegV1
}
//tail 3
if (iCounter == 1) {
// 3
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
//1
{
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
//tail 2
} else {
// if (iCounter == 1) {
// // 3
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
// LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
// LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// // 2
// {
// block_sync_lds();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
// Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// }
// //1
// {
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// }
// //tail 2
// } else
{
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
......
......@@ -70,6 +70,21 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return a_lds_block_desc;
}
// Builds a 3-D naive tensor descriptor for the A block tile with
//   lengths (kK / 8, kM, 8) and strides (kM * 8, 8, 1).
// These strides are exactly the dense row-major strides for those lengths, so
// no per-row padding is inserted (hence "linear").
// Block sizes are read from Problem::BlockGemmShape.
// NOTE(review): presumably describes the A tile as laid out in LDS, judging by
// the function name - confirm against the callers.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockLinearDescriptor()
{
    using namespace ck_tile;

    // Innermost dimension length; K is split into (kK / kPack) groups of kPack.
    constexpr index_t kPack  = 8;
    constexpr index_t kRows  = Problem::BlockGemmShape::kM;
    constexpr index_t kDepth = Problem::BlockGemmShape::kK;

    constexpr auto lengths =
        make_tuple(number<kDepth / kPack>{}, number<kRows>{}, number<kPack>{});
    constexpr auto strides =
        make_tuple(number<kRows * kPack>{}, number<kPack>{}, number<1>{});

    // NOTE(review): the trailing two arguments look like the guaranteed
    // vector length / stride of the last dimension - confirm against
    // make_naive_tensor_descriptor's signature.
    return make_naive_tensor_descriptor(lengths, strides, number<kPack>{}, number<1>{});
}
// 3d + padding
template <typename Problem>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment