Commit 532eb870 authored by coderfeli's avatar coderfeli
Browse files

fix warning and use default epilog and one out

parent 613e45b9
...@@ -48,14 +48,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -48,14 +48,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType, // using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
CDataType, // CDataType,
M_Warp * N_Warp * K_Warp * Warp_Size, // M_Warp * N_Warp * K_Warp * Warp_Size,
TilePartitioner::kM, // 64,
TilePartitioner::kN, // TilePartitioner::kN,
kPadM, // kPadM,
kPadN>>; // kPadN>>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
......
...@@ -32,8 +32,8 @@ struct CShuffleEpilogueV2Problem ...@@ -32,8 +32,8 @@ struct CShuffleEpilogueV2Problem
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
{ {
static constexpr index_t kMPerBlock = 64; constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock;
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
...@@ -45,10 +45,10 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; } ...@@ -45,10 +45,10 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{ {
static constexpr index_t kMPerBlock = 64; constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock; constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveSize = get_warp_size();
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
// using OLayout = remove_cvref_t<typename Problem::OLayout>; // using OLayout = remove_cvref_t<typename Problem::OLayout>;
...@@ -83,8 +83,9 @@ struct CShuffleEpilogueV2 ...@@ -83,8 +83,9 @@ struct CShuffleEpilogueV2
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore; static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr bool kMPerBlock = 64; static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr bool kNPerBlock = Problem::kNPerBlock; // static constexpr bool kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536;}//kMPerBlock * kNPerBlock * sizeof(ODataType); } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536;}//kMPerBlock * kNPerBlock * sizeof(ODataType); }
...@@ -104,6 +105,7 @@ struct CShuffleEpilogueV2 ...@@ -104,6 +105,7 @@ struct CShuffleEpilogueV2
auto o_dram_distri = MakeODramTileDistribution<Problem>(); auto o_dram_distri = MakeODramTileDistribution<Problem>();
auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri)); auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri));
store_tile(o_dram_window_tmp, o_dram_tile); store_tile(o_dram_window_tmp, o_dram_tile);
block_sync_lds();
} }
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -212,61 +212,21 @@ struct GemmKernel ...@@ -212,61 +212,21 @@ struct GemmKernel
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile()); EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
CSubTileDistr c_sub_tile; // using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
// printf("!!!!!!!!!!!!!!!!!!!!");
// c_sub_tile.get_tile_distribution().print(); // static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
// if (threadIdx.x==0) { // {
// printf("!!!!!!!!!!!!!!!!!!!!~~~ %d %d\n", c_block_tile.get_tile_distribution().get_num_of_dimension_y(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y()); // CSubTileDistr c_sub_tile;
// // c_block_tile.get_tile_distribution().print();
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{}; // constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths()); // constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_y_index_zeros.print(); // c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
// c_sub_y_lengths.print(); // merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
// } // merge_sequences(sequence<1>{}, c_sub_y_lengths));
// auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// if (threadIdx.x == 0) {
// c_sub_y_index_zeros.print();
// printf("\n");
// c_sub_y_lengths.print();
// printf("\n");
// printf("%d %d\n", GemmPipeline::NumCSubTile(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// }
// auto tbuf = c_block_tile.get_thread_buffer();
// for (index_t i = 0; i < tbuf.size(); i++) {
// if (threadIdx.x<16) {
// tbuf.set_as(i, float(threadIdx.x * 100 + i));
// } else {
// tbuf.set_as(i, float(threadIdx.x));
// }
// }
// c_block_tile.get_thread_buffer() = tbuf;
// if (threadIdx.x ==0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
{
constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
merge_sequences(sequence<1>{}, c_sub_y_lengths));
EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr); // EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0}); // move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
// });
});
} }
}; };
......
...@@ -249,36 +249,6 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -249,36 +249,6 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1 // LDS write 1
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
...@@ -321,29 +291,30 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -321,29 +291,30 @@ struct GemmPipelineAGmemBGmemCRegV1
} }
//tail 3 //tail 3
if (iCounter == 1) { // if (iCounter == 1) {
// 3 // // 3
{ // {
block_sync_lds(); // block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); // LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); // LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); // block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} // }
// 2 // // 2
{ // {
block_sync_lds(); // block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); // Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); // block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} // }
//1 // //1
{ // {
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); // block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} // }
//tail 2 // //tail 2
} else { // } else
{
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
......
...@@ -70,6 +70,21 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -70,6 +70,21 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return a_lds_block_desc; return a_lds_block_desc;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockLinearDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock) * 8>{}, number<8>{}, number<1>{}),
number<8>{},
number<1>{});
return a_lds_block_desc_0;
}
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment