Commit 613e45b9 authored by root's avatar root
Browse files

cshuffle v2 result correct, but perf awful

parent 801f995c
...@@ -201,4 +201,15 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number ...@@ -201,4 +201,15 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks; return unpacks;
} }
template <typename StaticTensor>
CK_TILE_DEVICE void dump_static_tensor(StaticTensor& t){
constexpr auto span_2d = decltype(t)::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
printf("%f,", type_convert<float>(t(i_j_idx)));
});
printf("\n");
});
}
} // namespace ck_tile } // namespace ck_tile
...@@ -22,8 +22,8 @@ struct CShuffleEpilogueV2Problem ...@@ -22,8 +22,8 @@ struct CShuffleEpilogueV2Problem
using ODataType = remove_cvref_t<ODataType_>; using ODataType = remove_cvref_t<ODataType_>;
// static constexpr bool UseRawStore = UseRawStore_; // static constexpr bool UseRawStore = UseRawStore_;
static constexpr index_t kBlockSize = kBlockSize_; static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t MPerBlock = kM_; static constexpr index_t kMPerBlock = kM_;
static constexpr index_t NPerBlock = kN_; static constexpr index_t kNPerBlock = kN_;
static constexpr bool kPadM = kPadM_; static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
}; };
...@@ -32,14 +32,12 @@ struct CShuffleEpilogueV2Problem ...@@ -32,14 +32,12 @@ struct CShuffleEpilogueV2Problem
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
{ {
static constexpr index_t kMPerBlock = Problem::MPerBlock; static constexpr index_t kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::NPerBlock; static constexpr index_t kNPerBlock = Problem::kNPerBlock;
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
make_tuple(number<kNPerBlock>{}, number<1>{}), make_tuple(number<kNPerBlock>{}, number<1>{}));
number<1>{},
number<1>{});
} }
...@@ -47,8 +45,8 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; } ...@@ -47,8 +45,8 @@ CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{ {
static constexpr index_t kMPerBlock = Problem::MPerBlock; static constexpr index_t kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::NPerBlock; static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size(); static constexpr index_t WaveSize = get_warp_size();
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
...@@ -85,36 +83,17 @@ struct CShuffleEpilogueV2 ...@@ -85,36 +83,17 @@ struct CShuffleEpilogueV2
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore; static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr bool kMPerBlock = Problem::MPerBlock; static constexpr bool kMPerBlock = 64;
static constexpr bool kNPerBlock = Problem::NPerBlock; static constexpr bool kNPerBlock = Problem::kNPerBlock;
// constexpr auto a_warp_y_lengths = CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536;}//kMPerBlock * kNPerBlock * sizeof(ODataType); }
// to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
// merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// dst_out.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
// dst_warp_tensor.get_thread_buffer());
// });
// });
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kMPerBlock * kNPerBlock * sizeof(ODataType); }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size // TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ? // how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile> template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void *p_smem) CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void *p_smem)
{ {
block_sync_lds();
auto o_lds_tile = cast_tile<ODataType>(o_acc_tile); auto o_lds_tile = cast_tile<ODataType>(o_acc_tile);
constexpr auto o_lds_block_desc = MakeOLdsBlockDescriptor<Problem>(); constexpr auto o_lds_block_desc = MakeOLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem), o_lds_block_desc); auto o_lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem), o_lds_block_desc);
...@@ -122,17 +101,6 @@ struct CShuffleEpilogueV2 ...@@ -122,17 +101,6 @@ struct CShuffleEpilogueV2
store_tile(o_lds_window0, o_lds_tile); store_tile(o_lds_window0, o_lds_tile);
block_sync_lds(); block_sync_lds();
// if (threadIdx.x == 0) {
// printf("%f, %f\n",type_convert<float>(static_cast<ODataType*>(p_smem)[32767]), type_convert<float>(static_cast<ODataType*>(p_smem)[32768]));
// constexpr auto span_2d = decltype(o_lds_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(o_lds_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
auto o_dram_distri = MakeODramTileDistribution<Problem>(); auto o_dram_distri = MakeODramTileDistribution<Problem>();
auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri)); auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri));
store_tile(o_dram_window_tmp, o_dram_tile); store_tile(o_dram_window_tmp, o_dram_tile);
......
...@@ -175,7 +175,7 @@ struct BlockGemmARegBRegCRegV2 ...@@ -175,7 +175,7 @@ struct BlockGemmARegBRegCRegV2
sequence<>, sequence<>,
tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>, tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>, tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>, tuple<sequence<0, 1>>,
sequence<2>, sequence<2>,
sequence<0>>{}; sequence<0>>{};
......
...@@ -213,16 +213,59 @@ struct GemmKernel ...@@ -213,16 +213,59 @@ struct GemmKernel
{i_m, i_n}); {i_m, i_n});
using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile()); using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
CSubTileDistr c_sub_tile;
// printf("!!!!!!!!!!!!!!!!!!!!");
// c_sub_tile.get_tile_distribution().print();
// if (threadIdx.x==0) {
// printf("!!!!!!!!!!!!!!!!!!!!~~~ %d %d\n", c_block_tile.get_tile_distribution().get_num_of_dimension_y(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// // c_block_tile.get_tile_distribution().print();
// constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// c_sub_y_index_zeros.print();
// c_sub_y_lengths.print();
// }
// auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
// auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
// if (threadIdx.x == 0) {
// c_sub_y_index_zeros.print();
// printf("\n");
// c_sub_y_lengths.print();
// printf("\n");
// printf("%d %d\n", GemmPipeline::NumCSubTile(), c_sub_tile.get_tile_distribution().get_num_of_dimension_y());
// }
// auto tbuf = c_block_tile.get_thread_buffer();
// for (index_t i = 0; i < tbuf.size(); i++) {
// if (threadIdx.x<16) {
// tbuf.set_as(i, float(threadIdx.x * 100 + i));
// } else {
// tbuf.set_as(i, float(threadIdx.x));
// }
// }
// c_block_tile.get_thread_buffer() = tbuf;
// if (threadIdx.x ==0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0) static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
{ {
auto c_sub_tile = make_static_distributed_tensor<CDataType>(CSubTileDistr{}); constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<c_sub_tile.get_tile_distribution().get_num_of_dimension_y(), 0>{};
constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<CSubTileDistr::NDimY, 0>{}; constexpr auto c_sub_y_lengths = to_sequence(c_sub_tile.get_tile_distribution().get_ys_to_d_descriptor().get_lengths());
constexpr auto c_sub_y_lengths = to_sequence(CSubTileDistr{}.get_ys_to_d_descriptor().get_lengths());
c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros), merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
merge_sequences(sequence<1>{}, c_sub_y_lengths)); merge_sequences(sequence<1>{}, c_sub_y_lengths));
EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr); EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0}); move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
}); });
} }
}; };
......
...@@ -11,15 +11,13 @@ namespace ck_tile { ...@@ -11,15 +11,13 @@ namespace ck_tile {
// A Tile Window: global memory // A Tile Window: global memory
// B Tile Window: global memory // B Tile Window: global memory
// C Distributed tensor: register // C Distributed tensor: register
template <typename Problem, typename Policy_ = GemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV1 struct GemmPipelineAGmemBGmemCRegV1
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using Policy = Policy_;
using Problem = Problem;
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment