Commit e889d086 authored by feifei14119

save 51

parent e076a320
@@ -12,6 +12,7 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#if 1
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
@@ -117,6 +118,129 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
return ave_time;
}
#else
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
/*constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;*/
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
#if FEIFEI_DEBUG
/*using BlockFlatmmStruct = ck_tile::remove_cvref_t<decltype(CodegenFlatmmPolicy::template GetBlockFlatmm<CodegenPipelineProblem>())>;
auto block_flatmm = BlockFlatmmStruct(); // struct BlockFlatmmASmemBSmemCRegV1
//auto ADramTileDistr = CodegenFlatmmPolicy::template MakeADramTileDistribution<CodegenPipelineProblem>();
auto kernel = Kernel{};
using SplitKBatchOffset = typename Kernel::SplitKBatchOffset;
SplitKBatchOffset splitk_batch_offset(args);
auto gemm_tensor_views_tuple = Kernel::template MakeGemmTensorViews<ck_tile::memory_operation_enum::set>(
args.a_ptr,
args.b_shuffle_ptr,
args.c_ptr,
kargs, splitk_batch_offset);*/
printf("[FEIFEI] --- flatmm_calc() ---\n");
printf("[FEIFEI] BlockPerCu = %d\n", static_cast<int>(kBlockPerCu));
printf("[FEIFEI] BlockTile M = %d\n", static_cast<int>(M_Tile));
printf("[FEIFEI] BlockTile N = %d\n", static_cast<int>(N_Tile));
printf("[FEIFEI] BlockTile K = %d\n", static_cast<int>(K_Tile));
printf("[FEIFEI] WavePerBlock M = %d\n", static_cast<int>(M_Warp));
printf("[FEIFEI] WavePerBlock N = %d\n", static_cast<int>(N_Warp));
printf("[FEIFEI] WavePerBlock K = %d\n", static_cast<int>(K_Warp));
printf("[FEIFEI] WaveTile M = %d\n", static_cast<int>(M_Warp_Tile));
printf("[FEIFEI] WaveTile N = %d\n", static_cast<int>(N_Warp_Tile));
printf("[FEIFEI] WaveTile K = %d\n", static_cast<int>(K_Warp_Tile));
printf("[FEIFEI] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z);
printf("[FEIFEI] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z);
#endif
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
#endif
#include "run_flatmm_example.inc" #include "run_flatmm_example.inc"
...
@@ -80,12 +80,12 @@ auto create_args(int argc, char* argv[])
.insert("n", "128", "n dimension") // 128, 4096
.insert("k", "64", "k dimension") // 64, 2048
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride") .insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride") .insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride") .insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
...
@@ -415,8 +415,8 @@ int run_flatmm_example_with_layouts(int argc,
// b_shuffle
{
std::ofstream file("ff_b_shuffle_host.txt");
int X = 32 * 32;
int Y = static_cast<int>(N) * static_cast<int>(M) / X;
file << " [b_shuffle_host]: Row = " << Y << ", Col = " << X << std::endl;
for(int y = 0; y < Y; y++)
...
@@ -24,10 +24,10 @@
// kernel
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -55,14 +55,13 @@ struct BlockFlatmmASmemBSmemCRegV1
return c_block_tensor;
}
#if 0
// C += A * B
// template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
template <typename ABlockWindow, typename BBlockWindow>
CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window, const BBlockWindow& b_block_window
#if FEIFEI_DEBUG
,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
@@ -101,14 +100,12 @@ struct BlockFlatmmASmemBSmemCRegV1
*/
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -117,11 +114,11 @@ struct BlockFlatmmASmemBSmemCRegV1
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
@@ -133,6 +130,7 @@ struct BlockFlatmmASmemBSmemCRegV1
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
@@ -151,7 +149,29 @@ struct BlockFlatmmASmemBSmemCRegV1
// Warp loop in block:
constexpr index_t kIter = 0;
constexpr index_t mIter = 0;
const auto a_warp_tensor = load_tile(a_warp_window_tmp);
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[BLOCK ] WG::kM = %d, WG::kM = %d, WG::kK = %d, WG::kKPerThread = %d\n", WG::kM, WG::kN, WG::kK, WG::kKPerThread);
printf("[BLOCK ] MIterPerWarp = %d, NIterPerWarp = %d, KIterPerWarp = %d\n", MIterPerWarp, NIterPerWarp, KIterPerWarp);
}
// debug A lds read
int warp_tile_size_per_thread = a_warp_tensor.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[BLOCK ] warp_tile_size_per_thread = %d\n", warp_tile_size_per_thread);
}
for(auto i = 0; i < warp_tile_size_per_thread; i++)
{
dbg_f16[gid * DEBUG_CNT + i] = a_warp_tensor.get_thread_buffer()[i];
}
return ;
#endif
#if 1
// feifei TODO: Implement gemm here
...
@@ -141,7 +141,7 @@ struct FlatmmKernel
struct SplitKBatchOffset
{
CK_TILE_DEVICE SplitKBatchOffset(const FlatmmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
@@ -175,7 +175,42 @@ struct FlatmmKernel
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
}
}
#if FEIFEI_DEBUG
CK_TILE_HOST SplitKBatchOffset(const FlatmmHostArgs& hargs,
const std::size_t k_id = 0)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = hargs.k_batch * K1;
const index_t KRead = (hargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * KRead;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * KRead * hargs.stride_A;
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * hargs.stride_B;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
}
if(k_id < static_cast<uint32_t>(hargs.k_batch - 1))
{
splitted_k = KRead;
}
else
{
splitted_k = hargs.K - KRead * (hargs.k_batch - 1);
}
}
#endif
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k; // problem K after split
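The host-side constructor added above mirrors the device one; a worked example of the rounding, assuming the warp K tile of this commit (K1 = 16) and a hypothetical K = 2048 problem split into k_batch = 2:

constexpr int K1     = 16;                            // WarpTile K (K_Warp_Tile in this commit)
constexpr int K_prob = 2048;                          // assumed problem K
constexpr int KBatch = 2;                             // assumed split-K factor
constexpr int K_t    = KBatch * K1;                   // 32
constexpr int KRead  = (K_prob + K_t - 1) / K_t * K1; // 1024 elements read per batch
static_assert(KRead == 1024, "each of the first KBatch-1 batches consumes KRead elements");
static_assert(K_prob - KRead * (KBatch - 1) == 1024, "the last batch gets the remainder");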
@@ -362,6 +397,9 @@ struct FlatmmKernel
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
}
#if 1
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
@@ -446,6 +484,118 @@ struct FlatmmKernel
return make_tuple(a_block_window, b_block_window, c_block_window);
}
#else
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
}();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& c_pad_view = views.at(I2);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, b_block_window, c_block_window);
}
#endif
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
@@ -477,20 +627,29 @@ struct FlatmmKernel
#endif
)
{
#if FEIFEI_DEBUG
uint32_t tidx = threadIdx.x;
uint32_t tidy = threadIdx.y;
uint32_t bidx = blockIdx.x;
uint32_t bidy = blockIdx.y;
uint32_t bdmx = blockDim.x;
uint32_t bdmy = blockDim.y;
uint32_t gdmx = gridDim.x;
uint32_t gdmy = gridDim.y;
uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
half_t* dbg_f16 = static_cast<half_t*>(kargs.dbg_f168_ptr);
#endif
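// Note on gid above: it is simply the row-major flattening of
// (blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x), i.e.
// gid = ((blockIdx.y * gridDim.x + blockIdx.x) * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x,
// so every thread writes into its own DEBUG_CNT-sized slot of the dbg_f168 buffer.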
// Create Flatmm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
// Debug origin layout
// const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
// a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
const auto& gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
#if FEIFEI_DEBUG
////////////////////////////////////////////////////////
const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view
@@ -533,39 +692,51 @@ struct FlatmmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_shuffle_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}();
const auto& b_flat_pad_view = [&]() {
return pad_tensor_view(b_flat_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}();
const auto& b_flat_block_window = make_tile_window(
b_flat_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{block_idx_n, 0});
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
b_flat_block_window,
num_loop,
smem_ptr
#if FEIFEI_DEBUG
,
b_block_window,
dbg_int,
dbg_fp32,
dbg_f168
#endif
);
// Run Epilogue Pipeline
/*auto c_block_window = gemm_tile_windows.at(I2);
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] C = %.3f\n", type_convert<float>(c_block_tile.get_thread_buffer()[0]));
}
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile, smem_ptr);*/
}
CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
...
@@ -14,6 +14,10 @@ namespace ck_tile {
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy> // feifei TODO: add default policy
struct FlatmmPipelineAGmemBGmemCRegV1
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
@@ -62,18 +66,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename BElementFunction
#if FEIFEI_DEBUG
, typename BDramBlockWindowTmp
#endif
>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem
#if FEIFEI_DEBUG
,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
@@ -111,63 +119,107 @@ struct FlatmmPipelineAGmemBGmemCRegV1
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] kMPerBlock = %d, winN = %d\n", kMPerBlock,
static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock,
static_cast<int>(BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock,
static_cast<int>(BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
printf("[PIPELN] kKPerBlock = %d, winN = %d\n", kKPerBlock,
static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}]));
}
#if 1
// feifei TODO: Implement gemm here
// Get block flatmm
auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1
// A DRAM tile window for load
auto a_copy_dram_window = // tile_window_with_static_distribution
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window = // tile_window_with_static_distribution
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeBDramTileDistribution<Problem>());
// B flat DRAM window for load
auto b_flat_distribution = PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<kNPerBlock>{}, number<BlockSize>{} * 4),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// Prefetch -----------------------------------------------------------
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
auto b_flat_tile = load_tile(b_flat_dram_window);
#if FEIFEI_DEBUG
// debug A global load
int a_block_tile_size_per_thread = a_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] a_block_tile_size_per_thread = %d\n", a_block_tile_size_per_thread);
}
for(auto i = 0; i < a_block_tile_size_per_thread; i++)
{
dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i];
}
// debug B global load
int b_block_tile_size_per_thread = b_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] b_block_tile_size_per_thread = %d\n", b_block_tile_size_per_thread);
}
for(auto i = 0; i < b_block_tile_size_per_thread; i++)
{
//dbg_f16[gid * DEBUG_CNT + i] = b_block_tile.get_thread_buffer()[i];
}
// debug flat B global load
int b_flat_tile_size_per_thread = b_flat_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] b_flat_tile_size_per_thread = %d\n", b_flat_tile_size_per_thread);
}
for(auto i = 0; i < b_flat_tile_size_per_thread; i++)
{
//dbg_f16[gid * DEBUG_CNT + i + b_block_tile_size_per_thread + 4] = b_flat_tile.get_thread_buffer()[i];
}
return nullptr;
#endif
#if 0
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window( // tile_window_with_static_lengths
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window( // tile_window_with_static_lengths
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
@@ -183,12 +235,26 @@ struct FlatmmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// B tile in LDS
constexpr index_t a_lds_block_space_size_aligned = integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
BDataType* p_b_lds = static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Loop ---------------------------------------------------------------
// Do flatmm
block_sync_lds();
block_flatmm(a_lds_gemm_window, b_lds_gemm_window
#if FEIFEI_DEBUG
,
dbg_int,
dbg_fp32,
dbg_f168
@@ -198,6 +264,157 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// Tail ---------------------------------------------------------------
return nullptr;
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// A tile in LDS
/*ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc =
PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = BlockFlatmm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
// prefetch
// global read 0
//auto a_block_tile = load_tile(a_copy_dram_window);
//auto b_block_tile = load_tile(b_copy_dram_window);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// LDS write 0
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
iCounter--;
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
int c_block_tile_size_per_thread = c_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] c_block_tile_size_per_thread = %d\n", c_block_tile_size_per_thread);
}
for(auto i = 0; i < c_block_tile_size_per_thread; i++)
{
//dbg_fp32[gid * DEBUG_CNT + i] = c_block_tile.get_thread_buffer()[i];
dbg_fp32[gid * DEBUG_CNT + i] = 3.12f;
c_block_tile.get_thread_buffer()[i] = 1.23f;
}
return c_block_tile;*/
////////////////////////////////////////////////////////////////////////////////////////////////////
#else
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
@@ -352,14 +569,18 @@ struct FlatmmPipelineAGmemBGmemCRegV1
#endif
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp
#if FEIFEI_DEBUG
, typename BDramBlockWindowTmp
#endif
>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem
#if FEIFEI_DEBUG
,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
@@ -369,13 +590,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
b_flat_dram_block_window_tmp,
[](const BDataType & b) { return b; },
num_loop,
p_smem
#if FEIFEI_DEBUG
,
b_dram_block_window_tmp,
dbg_int,
dbg_fp32,
dbg_f168
...
@@ -227,15 +227,24 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); // dwordx4 load A elem cnt
constexpr index_t K0 = KPerBlock / K1; // threads cnt in K dim
constexpr index_t M2 = get_warp_size() / K0; // threads cnt in M dim (per wave)
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size(); // wave cnt in M dim (per block)
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1); // load repeat times in M dim
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] MakeADramTileDistribution():\n");
printf("[PIPELN] MPerBlock = %d, KPerBlock = %d, AperBlock = %d\n", MPerBlock, KPerBlock, MPerBlock*KPerBlock);
printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
printf("[PIPELN] K1 = %d, K0 = %d, M2 = %d, M1 = %d, M0 = %d\n", K1, K0, M2, M1, M0);
}
#endif
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
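For the configuration in this commit the split works out as follows; a rough check, assuming fp16 A data, 16-byte (dwordx4) vector loads, 64-lane waves and a 256-thread block (values not stated in this hunk):

constexpr int VectorLoadSize = 16, KPerBlock = 64, MPerBlock = 128, BlockSize = 256, WarpSize = 64;
constexpr int K1 = VectorLoadSize / 2;    // 8 fp16 elements per lane per load
constexpr int K0 = KPerBlock / K1;        // 8 lanes along K
constexpr int M2 = WarpSize / K0;         // 8 M rows per wave
constexpr int M1 = BlockSize / WarpSize;  // 4 waves along M
constexpr int M0 = MPerBlock / (M2 * M1); // 4 load repeats along M
static_assert(M0 * M1 * M2 == MPerBlock && K0 * K1 == KPerBlock,
              "the (M0,M1,M2) x (K0,K1) split tiles the 128x64 A block exactly");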
@@ -310,18 +319,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
constexpr index_t K0 = KPerBlock / K1; // threads cnt in K dim
constexpr index_t N2 = get_warp_size() / K0; // threads cnt in N dim (per wave)
// coalesce reading for each block
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size(); // wave cnt in N dim (per block)
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1); // load repeat times in N dim
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] MakeBDramTileDistribution():\n");
printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n", NPerBlock, KPerBlock, NPerBlock*KPerBlock);
printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
printf("[PIPELN] K1 = %d, K0 = %d, N2 = %d, N1 = %d, N0 = %d\n", K1, K0, N2, N1, N0);
}
#endif
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
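The B-side numbers mirror the A-side ones under the same assumptions (fp16 data, 16-byte loads, 64-lane waves, 256 threads):

constexpr int N2 = 8, N1 = 4, N0 = 128 / (N2 * N1); // K0 = K1 = 8 as on the A side
static_assert(N0 == 4, "4 load repeats along N cover the 128-wide B block");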
@@ -347,6 +363,92 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t KLoad = Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
constexpr index_t KThdInBlk = 64;
constexpr index_t KBlkInTile = 1;
constexpr index_t KRepeat = 1;
constexpr index_t NLoad = 1; // dwordx4 load B elem cnt
constexpr index_t NThdInBlk = 1;
constexpr index_t NBlkInTile = 4;
constexpr index_t NRepeat = 1;
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] MakeBFlatDramTileDistribution():\n");
printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n",
NPerBlock,
KPerBlock,
NPerBlock * KPerBlock);
printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n",
BlockSize,
get_warp_size(),
Problem::VectorLoadSize);
printf("[PIPELN] NRepeat = %d, NBlkInTile = %d, NThdInBlk = %d, NLoad = %d\n", NRepeat, NBlkInTile, NThdInBlk, NLoad);
printf("[PIPELN] KRepeat = %d, KBlkInTile = %d, KThdInBlk = %d, KLoad = %d\n", KRepeat, KBlkInTile, KThdInBlk, KLoad);
}
#endif
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NRepeat, NBlkInTile, NThdInBlk, NLoad>,
sequence<KRepeat, KBlkInTile, KThdInBlk, KLoad>>, // first dim
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<2, 2>,
sequence<0, 3>>{});
}
}
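A rough footprint check for the flat-B (column-major/shuffled) branch above, assuming fp16 B data and 16-byte dwordx4 loads; the 64-lane and 4-wave figures come from the KThdInBlk and NBlkInTile constants in the hunk:

constexpr int VectorLoadSize = 16;                  // bytes per vector load (assumption)
constexpr int KLoad          = VectorLoadSize / 2;  // 8 fp16 elements per lane per load
constexpr int KThdInBlk      = 64;                  // all 64 lanes of a wave laid out along K
constexpr int NBlkInTile     = 4;                   // the 4 waves of the block laid out along N
static_assert(KThdInBlk * KLoad == 512, "one wave reads a contiguous 512-element strip of shuffled B");
static_assert(NBlkInTile * KThdInBlk * KLoad == 2048, "one block issue covers 4 x 512 B elements (4 KiB)");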
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
...