Commit 801f995c authored by coderfeli's avatar coderfeli
Browse files

tmp: add smem cshuffle code (not yet debugged)

parent 5a2d93d4
......@@ -22,9 +22,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generated by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
......@@ -36,6 +33,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t Warp_Size = 64;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
......@@ -43,8 +41,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
......@@ -52,21 +48,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
using GemmEpilogue = ck_tile::CShuffleEpilogueV2<ck_tile::CShuffleEpilogueV2Problem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
M_Warp * N_Warp * K_Warp * Warp_Size,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
TilePartitioner::kN,
kPadM,
kPadN>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// Compile-time problem description consumed by CShuffleEpilogueV2 and the
/// free helper functions in this header.
///
/// The struct holds no state: it only re-exports its template arguments under
/// the member names that the epilogue code reads. The epilogue stages an
/// MPerBlock x NPerBlock tile whose LDS descriptor uses row-major strides.
///
/// @tparam AccDataType_ type exported as AccDataType (cv/ref qualifiers stripped)
/// @tparam ODataType_   type exported as ODataType (cv/ref qualifiers stripped)
/// @tparam kBlockSize_  exported as kBlockSize; MakeODramTileDistribution divides
///                      it by the warp size
/// @tparam kM_          exported as MPerBlock
/// @tparam kN_          exported as NPerBlock
/// @tparam kPadM_       exported as kPadM
/// @tparam kPadN_       exported as kPadN
template <typename AccDataType_,
          typename ODataType_,
          index_t kBlockSize_,
          index_t kM_,
          index_t kN_,
          bool kPadM_,
          bool kPadN_>
struct CShuffleEpilogueV2Problem
{
    // Element types.
    using AccDataType = remove_cvref_t<AccDataType_>;
    using ODataType   = remove_cvref_t<ODataType_>;

    // Tile geometry.
    static constexpr index_t MPerBlock  = kM_;
    static constexpr index_t NPerBlock  = kN_;
    static constexpr index_t kBlockSize = kBlockSize_;

    // Padding switches for the M and N dimensions.
    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;

    // NOTE: no UseRawStore switch is exposed by this problem type.
};
/// Builds the tensor descriptor of the LDS staging buffer used by the epilogue.
///
/// The descriptor has lengths (Problem::MPerBlock, Problem::NPerBlock) and
/// strides (Problem::NPerBlock, 1), i.e. a densely packed row-major layout.
///
/// @tparam Problem type providing the MPerBlock and NPerBlock constants
/// @return the descriptor produced by make_naive_tensor_descriptor
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
{
    // Plain constexpr locals: a variable with static storage duration is not
    // allowed inside a constexpr function before C++23, and it would stop the
    // result from being usable as a constant expression.
    constexpr index_t kMPerBlock = Problem::MPerBlock;
    constexpr index_t kNPerBlock = Problem::NPerBlock;

    // The two trailing number<1> arguments are presumably the guaranteed
    // vector length and stride of the last dimension -- TODO confirm.
    return make_naive_tensor_descriptor(
        make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
        make_tuple(number<kNPerBlock>{}, number<1>{}),
        number<1>{},
        number<1>{});
}
/// Returns a fixed shared-memory size of 65536 (64 * 1024), independent of any
/// problem type. The unit is presumably bytes -- TODO confirm.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return 64 * 1024;
}
/// Builds the tile distribution used to read the staged output tile back and
/// write it out.
///
/// The M dimension is split as M0 x M1 x M2 and the N dimension as
/// N0 x N1 x N2, where:
///   - N2 is a fixed 8 elements,
///   - N1 = min(NPerBlock / N2, warp size),
///   - N0 = ceil((NPerBlock / N2) / warp size),
///   - M2 = ceil(warp size / (NPerBlock / N2)),
///   - M1 = kBlockSize / warp size,
///   - M0 = ceil(MPerBlock / (M1 * M2)).
///
/// Only this single layout is produced. Problem carries no output layout
/// member, so a column-major variant is not implemented here.
///
/// @tparam Problem type providing MPerBlock, NPerBlock and kBlockSize
/// @return the distribution produced by make_static_tile_distribution
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{
    // Plain constexpr locals: static-storage variables are not allowed inside
    // a constexpr function before C++23.
    constexpr index_t kMPerBlock = Problem::MPerBlock;
    constexpr index_t kNPerBlock = Problem::NPerBlock;
    constexpr index_t BlockSize  = Problem::kBlockSize;
    constexpr index_t WaveSize   = get_warp_size();

    // Innermost N length. Presumably the per-thread vector width; the value 8
    // is hard-coded regardless of the output element type -- TODO confirm.
    constexpr index_t N2 = 8;
    static_assert(kNPerBlock >= N2 && kNPerBlock % N2 == 0,
                  "NPerBlock must be a non-zero multiple of the innermost N length (8)");

    constexpr index_t NVectors = kNPerBlock / N2;

    constexpr index_t N1 = min(NVectors, WaveSize);
    constexpr index_t N0 = integer_divide_ceil(NVectors, WaveSize);
    constexpr index_t M2 = integer_divide_ceil(WaveSize, NVectors);
    constexpr index_t M1 = BlockSize / WaveSize;
    constexpr index_t M0 = integer_divide_ceil(kMPerBlock, M1 * M2);

    // First P dimension -> M1; second P dimension -> (M2, N1).
    // Y dimensions -> M0, N0, N2.
    // Presumably the first P dimension is the warp index and the second the
    // lane index -- TODO confirm against the tile_distribution conventions.
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<M0, M1, M2>, sequence<N0, N1, N2>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 1>>,
                                   sequence<1, 2, 2>,
                                   sequence<0, 0, 2>>{});
}
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogueV2
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr bool kMPerBlock = Problem::MPerBlock;
static constexpr bool kNPerBlock = Problem::NPerBlock;
// constexpr auto a_warp_y_lengths =
// to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
// merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// dst_out.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
// dst_warp_tensor.get_thread_buffer());
// });
// });
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kMPerBlock * kNPerBlock * sizeof(ODataType); }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void *p_smem)
{
auto o_lds_tile = cast_tile<ODataType>(o_acc_tile);
constexpr auto o_lds_block_desc = MakeOLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem), o_lds_block_desc);
auto o_lds_window0 = make_tile_window(o_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
store_tile(o_lds_window0, o_lds_tile);
block_sync_lds();
// if (threadIdx.x == 0) {
// printf("%f, %f\n",type_convert<float>(static_cast<ODataType*>(p_smem)[32767]), type_convert<float>(static_cast<ODataType*>(p_smem)[32768]));
// constexpr auto span_2d = decltype(o_lds_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(o_lds_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
auto o_dram_distri = MakeODramTileDistribution<Problem>();
auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri));
store_tile(o_dram_window_tmp, o_dram_tile);
}
};
} // namespace ck_tile
......@@ -169,6 +169,23 @@ struct BlockGemmARegBRegCRegV2
return c_block_tensor;
}
// Creates a distributed tensor of CDataType for a sub-tile of the C block.
//
// The outer encoding has H lengths (MWarp) along the first dimension and
// (NIterPerWarp, NWarp) along the second. Its single P dimension maps to
// MWarp and NWarp, and its single Y dimension maps to NIterPerWarp. That outer
// encoding is embedded with WG::CWarpDstrEncoding to form the distribution.
//
// Note that the return value is a tensor object, not a distribution.
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
{
    using SubTileOuterEncoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<MWarp>, sequence<NIterPerWarp, NWarp>>,
                                   tuple<sequence<1, 2>>,
                                   tuple<sequence<1, 1>>,
                                   sequence<2>,
                                   sequence<0>>;

    constexpr auto sub_tile_encoding = detail::make_embed_tile_distribution_encoding(
        SubTileOuterEncoding{}, typename WG::CWarpDstrEncoding{});

    constexpr auto sub_tile_dstr = make_static_tile_distribution(sub_tile_encoding);

    return make_static_distributed_tensor<CDataType>(sub_tile_dstr);
}
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{
// M->N Warp
......
......@@ -212,7 +212,18 @@ struct GemmKernel
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
using CSubTileDistr = decltype(GemmPipeline::MakeCBlockSubTile());
static_for<0, GemmPipeline::NumCSubTile(), 1>{}([&](auto i_m0)
{
auto c_sub_tile = make_static_distributed_tensor<CDataType>(CSubTileDistr{});
constexpr auto c_sub_y_index_zeros = uniform_sequence_gen_t<CSubTileDistr::NDimY, 0>{};
constexpr auto c_sub_y_lengths = to_sequence(CSubTileDistr{}.get_ys_to_d_descriptor().get_lengths());
c_sub_tile.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<i_m0>{}, c_sub_y_index_zeros),
merge_sequences(sequence<1>{}, c_sub_y_lengths));
EpiloguePipeline{}(CBlockWindow_pad, c_sub_tile, smem_ptr);
move_tile_window(CBlockWindow_pad, {TilePartitioner::kM / GemmPipeline::NumCSubTile(), 0});
});
}
};
......
......@@ -11,13 +11,15 @@ namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy_ = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using Policy = Policy_;
using Problem = Problem;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
......@@ -133,6 +135,16 @@ struct GemmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0);
}
// Forwards to MakeCBlockSubTile() of the block GEMM selected by the policy
// for this Problem, and returns whatever that function returns.
CK_TILE_DEVICE static constexpr auto MakeCBlockSubTile()
{
    using BlockGemmImpl = typename Policy::template BlockGemm<Problem>;
    return BlockGemmImpl::MakeCBlockSubTile();
}
// Number of C sub-tiles along M:
//   ceil(kMPerBlock / (BlockWarps[0] * WarpTile[0])).
// The original inline notes give 32 for WarpTile[0] and 2 for BlockWarps[0]
// as example values.
CK_TILE_DEVICE static constexpr auto NumCSubTile()
{
    constexpr index_t warps_along_m     = BlockGemmShape::BlockWarps::at(number<0>{});
    constexpr index_t warp_tile_rows    = BlockGemmShape::WarpTile::at(number<0>{});
    constexpr index_t rows_per_sub_tile = warps_along_m * warp_tile_rows;

    return integer_divide_ceil(kMPerBlock, rows_per_sub_tile);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
......@@ -180,26 +192,6 @@ struct GemmPipelineAGmemBGmemCRegV1
// global read 0
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbb\n");
// constexpr auto span_2d2 = decltype(b_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
////////////// LDS desc, window & register /////////////////
// AB LDS desc
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
......@@ -363,18 +355,16 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
{
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; %f, %f. ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)), type_convert<float>(a_block_tile1(i_j_idx)), type_convert<float>(b_block_tile1(i_j_idx)));
// });
// printf("\n");
// });
// }
}
}
/// cccccccccc
// constexpr auto c_lds_block_desc = Policy::template MakeCLdsBlockDescriptor<Problem>();
// auto c_lds_block = make_tensor_view<address_space_enum::lds>(reinterpret_cast<CDataType*>(p_smem), c_lds_block_desc);
// auto c_lds_window0 = make_tile_window(c_lds_block, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0, 0});
// store_tile(c_lds_window0, c_block_tile);
// block_sync_lds();
return c_block_tile;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment