Commit 801f995c authored by coderfeli's avatar coderfeli
Browse files

tmp:add smem cshuffle code but not debug

parent 5a2d93d4
...@@ -22,9 +22,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -22,9 +22,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -36,6 +33,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -36,6 +33,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t Warp_Size = 64;
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
...@@ -43,8 +41,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -43,8 +41,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Whether doing the CShuffle (transpose before the global memory), depending on the output // Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout. // layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape = using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
...@@ -52,21 +48,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -52,21 +48,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
using GemmEpilogue = std::conditional_t< CDataType,
CShuffleEpilogue, M_Warp * N_Warp * K_Warp * Warp_Size,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType, TilePartitioner::kM,
CDataType, TilePartitioner::kN,
kPadM, kPadM,
kPadN, kPadN>>;
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
typename ODataType_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
bool kPadM_,
bool kPadN_>
struct CShuffleEpilogueV2Problem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
// static constexpr bool UseRawStore = UseRawStore_;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t MPerBlock = kM_;
static constexpr index_t NPerBlock = kN_;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
{
static constexpr index_t kMPerBlock = Problem::MPerBlock;
static constexpr index_t kNPerBlock = Problem::NPerBlock;
return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
make_tuple(number<kNPerBlock>{}, number<1>{}),
number<1>{},
number<1>{});
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{
static constexpr index_t kMPerBlock = Problem::MPerBlock;
static constexpr index_t kNPerBlock = Problem::NPerBlock;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
using ODataType = remove_cvref_t<typename Problem::ODataType>;
// using OLayout = remove_cvref_t<typename Problem::OLayout>;
// if constexpr(std::is_same_v<OLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
// {
// static_assert(0, "not impl");
// }
constexpr index_t N2 = 8;
constexpr index_t N1 = min(kNPerBlock / N2, WaveSize);
constexpr index_t N0 = integer_divide_ceil(kNPerBlock / N2, WaveSize);
constexpr index_t M2 = integer_divide_ceil(WaveSize, kNPerBlock / N2);
constexpr index_t M1 = BlockSize / WaveSize;
constexpr index_t M0 = integer_divide_ceil(kMPerBlock, M1 * M2);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1, N2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogueV2
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr bool kMPerBlock = Problem::MPerBlock;
static constexpr bool kNPerBlock = Problem::NPerBlock;
// constexpr auto a_warp_y_lengths =
// to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
// merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// dst_out.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
// dst_warp_tensor.get_thread_buffer());
// });
// });
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kMPerBlock * kNPerBlock * sizeof(ODataType); }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void *p_smem)
{
auto o_lds_tile = cast_tile<ODataType>(o_acc_tile);
constexpr auto o_lds_block_desc = MakeOLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem), o_lds_block_desc);
auto o_lds_window0 = make_tile_window(o_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
store_tile(o_lds_window0, o_lds_tile);
block_sync_lds();
// if (threadIdx.x == 0) {
// printf("%f, %f\n",type_convert<float>(static_cast<ODataType*>(p_smem)[32767]), type_convert<float>(static_cast<ODataType*>(p_smem)[32768]));
// constexpr auto span_2d = decltype(o_lds_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(o_lds_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
auto o_dram_distri = MakeODramTileDistribution<Problem>();
auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri));
store_tile(o_dram_window_tmp, o_dram_tile);
}
};
} // namespace ck_tile
...@@ -169,6 +169,23 @@ struct BlockGemmARegBRegCRegV2 ...@@ -169,6 +169,23 @@ struct BlockGemmARegBRegCRegV2
return c_block_tensor; return c_block_tensor;
} }
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2>,
sequence<0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution() CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{ {
// M->N Warp // M->N Warp
......
...@@ -211,8 +211,19 @@ struct GemmKernel ...@@ -211,8 +211,19 @@ struct GemmKernel
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
{
auto c_sub_tile = make_static_distributed_tensor<CDataType>(CSubTileDistr{});
constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<CSubTileDistr::NDimY, 0>{};
constexpr auto c_sub_y_lengths = to_sequence(CSubTileDistr{}.get_ys_to_d_descriptor().get_lengths());
c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
merge_sequences(sequence<1>{}, c_sub_y_lengths));
EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
});
} }
}; };
......
...@@ -11,13 +11,15 @@ namespace ck_tile { ...@@ -11,13 +11,15 @@ namespace ck_tile {
// A Tile Window: global memory // A Tile Window: global memory
// B Tile Window: global memory // B Tile Window: global memory
// C Distributed tensor: register // C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy_ = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV1 struct GemmPipelineAGmemBGmemCRegV1
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using Policy = Policy_;
using Problem = Problem;
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
...@@ -133,6 +135,16 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -133,6 +135,16 @@ struct GemmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile() {
return Policy::template BlockGemm<Problem>::MakeCBlockSubTile();
}
CK_TILE_DEVICE static constexpr auto NumCSubTile() {
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});//32
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});//2
return integer_divide_ceil(kMPerBlock, WaveNumM * MPerXDL);
}
template <typename ADramBlockWindowTmp, template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp, typename BDramBlockWindowTmp,
typename AElementFunction, typename AElementFunction,
...@@ -180,26 +192,6 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -180,26 +192,6 @@ struct GemmPipelineAGmemBGmemCRegV1
// global read 0 // global read 0
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbb\n");
// constexpr auto span_2d2 = decltype(b_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
////////////// LDS desc, window & register ///////////////// ////////////// LDS desc, window & register /////////////////
// AB LDS desc // AB LDS desc
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
...@@ -363,18 +355,16 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -363,18 +355,16 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2 // 2
{ {
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; %f, %f. ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)), type_convert<float>(a_block_tile1(i_j_idx)), type_convert<float>(b_block_tile1(i_j_idx)));
// });
// printf("\n");
// });
// }
} }
} }
/// cccccccccc
// constexpr auto c_lds_block_desc = Policy::template MakeCLdsBlockDescriptor<Problem>();
// auto c_lds_block = make_tensor_view<address_space_enum::lds>(reinterpret_cast<CDataType*>(p_smem), c_lds_block_desc);
// auto c_lds_window0 = make_tile_window(c_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
// store_tile(c_lds_window0, c_block_tile);
// block_sync_lds();
return c_block_tile; return c_block_tile;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment