Commit e889d086 authored by feifei14119

save 51

parent e076a320
......@@ -12,6 +12,7 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#if 1
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
......@@ -117,6 +118,129 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
return ave_time;
}
#else
template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
/*constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;*/
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
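// Rough arithmetic for this configuration (illustrative only; the real mapping is
// owned by TilePartitioner and the pipeline policy): the 128x128x64 block tile is split
// over 2x2 waves, so each wave covers a 64x64 sub-tile and steps through it in
// 128/(2*32) x 128/(2*32) x 64/16 = 2 x 2 x 4 iterations of the 32x32x16 warp tile.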
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
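// How the pieces fit (as the type names suggest): TilePartitioner maps blockIdx onto an
// (M, N) block tile, CodegenFlatmmPipeline streams A through LDS while reading the
// pre-shuffled ("flat") B from global memory, and GemmEpilogue (CShuffle) writes the
// accumulator tile back to C.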
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
#if FEIFEI_DEBUG
/*using BlockFlatmmStruct = ck_tile::remove_cvref_t<decltype(CodegenFlatmmPolicy::template GetBlockFlatmm<CodegenPipelineProblem>())>;
auto block_flatmm = BlockFlatmmStruct(); // struct BlockFlatmmASmemBSmemCRegV1
//auto ADramTileDistr = CodegenFlatmmPolicy::template MakeADramTileDistribution<CodegenPipelineProblem>();
auto kernel = Kernel{};
using SplitKBatchOffset = typename Kernel::SplitKBatchOffset;
SplitKBatchOffset splitk_batch_offset(args);
auto gemm_tensor_views_tuple = Kernel::template MakeGemmTensorViews<ck_tile::memory_operation_enum::set>(
args.a_ptr,
args.b_shuffle_ptr,
args.c_ptr,
kargs, splitk_batch_offset);*/
printf("[FEIFEI] --- flatmm_calc() ---\n");
printf("[FEIFEI] BlockPerCu = %d\n", static_cast<int>(kBlockPerCu));
printf("[FEIFEI] BlockTile M = %d\n", static_cast<int>(M_Tile));
printf("[FEIFEI] BlockTile N = %d\n", static_cast<int>(N_Tile));
printf("[FEIFEI] BlockTile K = %d\n", static_cast<int>(K_Tile));
printf("[FEIFEI] WavePerBlock M = %d\n", static_cast<int>(M_Warp));
printf("[FEIFEI] WavePerBlock N = %d\n", static_cast<int>(N_Warp));
printf("[FEIFEI] WavePerBlock K = %d\n", static_cast<int>(K_Warp));
printf("[FEIFEI] WaveTile M = %d\n", static_cast<int>(M_Warp_Tile));
printf("[FEIFEI] WaveTile N = %d\n", static_cast<int>(N_Warp_Tile));
printf("[FEIFEI] WaveTile K = %d\n", static_cast<int>(K_Warp_Tile));
printf("[FEIFEI] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z);
printf("[FEIFEI] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z);
#endif
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
#endif
#include "run_flatmm_example.inc"
......
......@@ -80,12 +80,12 @@ auto create_args(int argc, char* argv[])
.insert("n", "128", "n dimension") // 128, 4096
.insert("k", "64", "k dimension") // 64, 2048
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
......
......@@ -415,8 +415,8 @@ int run_flatmm_example_with_layouts(int argc,
// b_shuffle
{
std::ofstream file("ff_b_shuffle_host.txt");
int X = static_cast<int>(K);
int Y = static_cast<int>(N);
int X = 32 * 32;
int Y = static_cast<int>(N) * static_cast<int>(M) / X;
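// Assumed dump layout: each printed row groups 32*32 = 1024 shuffled B elements
// (one warp tile's worth), and Y is simply the remaining extent of the buffer.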
file << " [b_shuffle_host]: Row = " << Y << ", Col = " << X << std::endl;
for(int y = 0; y < Y; y++)
......
......@@ -24,10 +24,10 @@
// kernel
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
//#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
......@@ -55,14 +55,13 @@ struct BlockFlatmmASmemBSmemCRegV1
return c_block_tensor;
}
#if 1
#if 0
// C += A * B
// template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
template <typename ABlockWindow>
CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window
template <typename ABlockWindow, typename BBlockWindow>
CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window, const BBlockWindow& b_block_window
#if FEIFEI_DEBUG
,
const BDataType* b_ptr,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
......@@ -101,14 +100,12 @@ struct BlockFlatmmASmemBSmemCRegV1
*/
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
/*
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
*/
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
......@@ -117,11 +114,11 @@ struct BlockFlatmmASmemBSmemCRegV1
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
// constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
// constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
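// Illustrative numbers (taking the example host config of a 128x128x64 block tile,
// 2x2 waves and a 32x32x16 warp GEMM; actual values come from Problem/config):
// MIterPerWarp = 128/(2*32) = 2, NIterPerWarp = 2, KIterPerWarp = 64/16 = 4.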
const index_t iMWarp = get_warp_id() / NWarp;
......@@ -134,6 +131,7 @@ struct BlockFlatmmASmemBSmemCRegV1
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
......@@ -151,7 +149,29 @@ struct BlockFlatmmASmemBSmemCRegV1
// Warp loop in block:
constexpr index_t kIter = 0;
constexpr index_t mIter = 0;
const auto a_warp_tensor = load_tile(a_warp_windows(number<mIter>{})(number<kIter>{}));
const auto a_warp_tensor = load_tile(a_warp_window_tmp);
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[BLOCK ] WG::kM = %d, WG::kM = %d, WG::kK = %d, WG::kKPerThread = %d\n", WG::kM, WG::kN, WG::kK, WG::kKPerThread);
printf("[BLOCK ] MIterPerWarp = %d, NIterPerWarp = %d, KIterPerWarp = %d\n", MIterPerWarp, NIterPerWarp, KIterPerWarp);
}
// debug A lds read
int warp_tile_size_per_thread = a_warp_tensor.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[BLOCK ] warp_tile_size_per_thread = %d\n", warp_tile_size_per_thread);
}
for(auto i = 0; i < warp_tile_size_per_thread; i++)
{
dbg_f16[gid * DEBUG_CNT + i] = a_warp_tensor.get_thread_buffer()[i];
}
return;
#endif
#if 1
// feifei TODO: Implement gemm here
......
......@@ -141,7 +141,7 @@ struct FlatmmKernel
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs,
CK_TILE_DEVICE SplitKBatchOffset(const FlatmmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
......@@ -175,7 +175,42 @@ struct FlatmmKernel
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
}
}
#if FEIFEI_DEBUG
CK_TILE_HOST SplitKBatchOffset(const FlatmmHostArgs& hargs,
const std::size_t k_id = 0)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = hargs.k_batch * K1;
const index_t KRead = (hargs.K + K_t - 1) / K_t * K1;
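// Example arithmetic with hypothetical sizes: K = 4096, k_batch = 2, K1 = 16 gives
// K_t = 32 and KRead = ceil(4096/32) * 16 = 2048, i.e. every split except the last
// reads a K1-aligned chunk of KRead elements and the last takes whatever K remains.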
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * KRead;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * KRead * hargs.stride_A;
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * hargs.stride_B;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
}
if(k_id < static_cast<uint32_t>(hargs.k_batch - 1))
{
splitted_k = KRead;
}
else
{
splitted_k = hargs.K - KRead * (hargs.k_batch - 1);
}
}
#endif
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k; // problem K after the split
......@@ -362,6 +397,9 @@ struct FlatmmKernel
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
}
#if 1
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
......@@ -446,6 +484,118 @@ struct FlatmmKernel
return make_tuple(a_block_window, b_block_window, c_block_window);
}
#else
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
}();
// TODO: vectorized writes for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& c_pad_view = views.at(I2);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, b_block_window, c_block_window);
}
#endif
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
......@@ -477,20 +627,29 @@ struct FlatmmKernel
#endif
)
{
#if FEIFEI_DEBUG
uint32_t tidx = threadIdx.x;
uint32_t tidy = threadIdx.y;
uint32_t bidx = blockIdx.x;
uint32_t bidy = blockIdx.y;
uint32_t bdmx = blockDim.x;
uint32_t bdmy = blockDim.y;
uint32_t gdmx = gridDim.x;
uint32_t gdmy = gridDim.y;
uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
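// gid linearizes (blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x) into a flat
// per-thread index used to address the debug buffers below.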
half_t* dbg_f16 = static_cast<half_t*>(kargs.dbg_f168_ptr);
#endif
// Create Flatmm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset);
// Debug origin layout
// const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
// a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
const auto& gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
#if FEIFEI_DEBUG
////////////////////////////////////////////////////////
const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view
......@@ -533,39 +692,51 @@ struct FlatmmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_shuffle_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}();
const auto& b_flat_pad_view = [&]() {
return pad_tensor_view(b_flat_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}();
const auto& b_flat_block_window = make_tile_window(
b_flat_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{block_idx_n, 0});
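// The pre-shuffled ("flat") B is viewed as an N x splitted_k row-major tensor and
// windowed at {block_idx_n, 0}, mirroring a regular B window; the shuffle layout
// inside the window is left to the pipeline policy.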
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
b_block_window,
b_flat_block_window,
num_loop,
smem_ptr
#if FEIFEI_DEBUG
,
b_ptr,
b_block_window,
dbg_int,
dbg_fp32,
dbg_f168
#endif
);
// feifei TODO: Un-comment below once pipeline() is implemented
#if 0
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
/*auto c_block_window = gemm_tile_windows.at(I2);
constexpr bool is_output_c_reg_transposed =
EpiloguePipeline::IsOutputTransposed() != FlatmmPipeline::IsTransposeC();
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
(FlatmmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed))
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] C = %.3f\n", type_convert<float>(c_block_tile.get_thread_buffer()[0]));
}
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile);
}
#endif
c_block_window, c_block_tile, smem_ptr);*/
}
CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
......
......@@ -14,6 +14,10 @@ namespace ck_tile {
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy> // feifei TODO: add default policy
struct FlatmmPipelineAGmemBGmemCRegV1
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
......@@ -62,18 +66,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
typename BElementFunction
#if FEIFEI_DEBUG
, typename BDramBlockWindowTmp
#endif
>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem
#if FEIFEI_DEBUG
,
const BDataType* b_ptr,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
......@@ -111,63 +119,107 @@ struct FlatmmPipelineAGmemBGmemCRegV1
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] kMPerBlock = %d, winN = %d\n", kMPerBlock,
static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock,
static_cast<int>(BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock,
static_cast<int>(BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}]));
printf("[PIPELN] kKPerBlock = %d, winN = %d\n", kKPerBlock,
static_cast<int>(ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}]));
}
#if 1
// feifei TODO: Implement gemm here
// Get block flatmm
auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
auto a_copy_dram_window = // tile_window_with_static_distribution
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B DRAM tile window for load
auto b_copy_dram_window = // tile_window_with_static_distribution
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeBDramTileDistribution<Problem>());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B flat DRAM window for load
auto b_flat_distribution = PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<kNPerBlock>{}, number<BlockSize>{} * 4),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
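// Note: the flat-B window K-extent here is BlockSize * 4 elements instead of
// kKPerBlock, presumably one vector fetch per thread per step; treat this as an
// experimental/debug choice of this commit rather than the final flat-B tiling.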
// Prefetch -----------------------------------------------------------
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
auto b_flat_tile = load_tile(b_flat_dram_window);
#if FEIFEI_DEBUG // debug A global load
int a_dim = a_block_tile.get_num_of_dimension();
int a_sz = a_block_tile.get_thread_buffer_size();
#if FEIFEI_DEBUG
// debug A global load
int a_block_tile_size_per_thread = a_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] a_dim = %d, a_sz = %d\n", a_dim, a_sz);
printf("[PIPELN] a_block_tile_size_per_thread = %d\n", a_block_tile_size_per_thread);
}
for(auto i = 0; i < a_sz; i++)
for(auto i = 0; i < a_block_tile_size_per_thread; i++)
{
dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i];
}
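// Debug-buffer convention assumed here: thread gid owns the slot
// dbg_f16[gid * DEBUG_CNT .. gid * DEBUG_CNT + DEBUG_CNT - 1], so DEBUG_CNT must be
// at least the per-thread tile buffer size being dumped.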
// debug B global load
int b_block_tile_size_per_thread = b_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] b_block_tile_size_per_thread = %d\n", b_block_tile_size_per_thread);
}
for(auto i = 0; i < b_block_tile_size_per_thread; i++)
{
//dbg_f16[gid * DEBUG_CNT + i] = b_block_tile.get_thread_buffer()[i];
}
// debug flat B global load
int b_flat_tile_size_per_thread = b_flat_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] b_flat_tile_size_per_thread = %d\n", b_flat_tile_size_per_thread);
}
for(auto i = 0; i < b_flat_tile_size_per_thread; i++)
{
//dbg_f16[gid * DEBUG_CNT + i + b_block_tile_size_per_thread + 4] = b_flat_tile.get_thread_buffer()[i];
}
return nullptr;
#endif
#if 0
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window( // tile_window_with_static_lengths
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window( // tile_window_with_static_lengths
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
......@@ -183,12 +235,26 @@ struct FlatmmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// B tile in LDS
constexpr index_t a_lds_block_space_size_aligned = integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
BDataType* p_b_lds = static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Loop ---------------------------------------------------------------
// Do flatmm
block_flatmm(a_lds_gemm_window
block_sync_lds();
block_flatmm(a_lds_gemm_window, b_lds_gemm_window
#if FEIFEI_DEBUG
,
b_ptr,
dbg_int,
dbg_fp32,
dbg_f168
......@@ -198,6 +264,157 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// Tail ---------------------------------------------------------------
return nullptr;
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// A tile in LDS
/*ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc =
PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = BlockFlatmm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
// prefetch
// global read 0
//auto a_block_tile = load_tile(a_copy_dram_window);
//auto b_block_tile = load_tile(b_copy_dram_window);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// LDS write 0
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
iCounter--;
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
int c_block_tile_size_per_thread = c_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] c_block_tile_size_per_thread = %d\n", c_block_tile_size_per_thread);
}
for(auto i = 0; i < c_block_tile_size_per_thread; i++)
{
//dbg_fp32[gid * DEBUG_CNT + i] = c_block_tile.get_thread_buffer()[i];
dbg_fp32[gid * DEBUG_CNT + i] = 3.12f;
c_block_tile.get_thread_buffer()[i] = 1.23f;
}
return c_block_tile;*/
////////////////////////////////////////////////////////////////////////////////////////////////////
#else
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
......@@ -352,14 +569,18 @@ struct FlatmmPipelineAGmemBGmemCRegV1
#endif
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp
#if FEIFEI_DEBUG
, typename BDramBlockWindowTmp
#endif
>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem
#if FEIFEI_DEBUG
,
const BDataType* b_ptr,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
......@@ -369,13 +590,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
b_dram_block_window_tmp,
b_flat_dram_block_window_tmp,
[](const BDataType & b) { return b; },
num_loop,
p_smem
#if FEIFEI_DEBUG
,
b_ptr,
b_dram_block_window_tmp,
dbg_int,
dbg_fp32,
dbg_f168
......
......@@ -227,15 +227,24 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); // A elements per vector (dwordx4) load
constexpr index_t K0 = KPerBlock / K1; // thread count along K
constexpr index_t M2 = get_warp_size() / K0; // thread count along M (per wave)
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
constexpr index_t M1 = BlockSize / get_warp_size(); // wave count along M (per block)
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
constexpr index_t M0 = MPerBlock / (M2 * M1); // load repeats along M
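// Worked example (illustrative, fp16 with 16-byte vector loads, 256 threads,
// 64-lane waves): K1 = 16/2 = 8, K0 = 64/8 = 8, M2 = 64/8 = 8, M1 = 256/64 = 4,
// M0 = 128/(8*4) = 4 load repeats per thread.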
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] MakeADramTileDistribution():\n");
printf("[PIPELN] MPerBlock = %d, KPerBlock = %d, AperBlock = %d\n", MPerBlock, KPerBlock, MPerBlock*KPerBlock);
printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
printf("[PIPELN] K1 = %d, K0 = %d, M2 = %d, M1 = %d, M0 = %d\n", K1, K0, M2, M1, M0);
}
#endif
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
......@@ -310,18 +319,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); // B elements per vector (dwordx4) load
constexpr index_t K0 = KPerBlock / K1; // thread count along K
constexpr index_t N2 = get_warp_size() / K0; // thread count along N (per wave)
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size(); // wave count along N (per block)
constexpr index_t N1 = BlockSize / get_warp_size(); // wave cnt in N dim (per block)
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1); // load repeats along N
constexpr index_t N0 = NPerBlock / (N2 * N1); // load repeat times in N dim
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] MakeBDramTileDistribution():\n");
printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n", NPerBlock, KPerBlock, NPerBlock*KPerBlock);
printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize);
printf("[PIPELN] K1 = %d, K0 = %d, N2 = %d, N1 = %d, N0 = %d\n", K1, K0, N2, N1, N0);
}
#endif
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
......@@ -347,6 +363,92 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t KLoad = Problem::VectorLoadSize / sizeof(BDataType); // B elements per vector (dwordx4) load
constexpr index_t KThdInBlk = 64;
constexpr index_t KBlkInTile = 1;
constexpr index_t KRepeat = 1;
constexpr index_t NLoad = 1; // elements per load along N (scalar here)
constexpr index_t NThdInBlk = 1;
constexpr index_t NBlkInTile = 4;
constexpr index_t NRepeat = 1;
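// Assumed mapping sketched by these constants: all 64 lanes of a wave line up along K
// (KThdInBlk = 64), each issuing one KLoad-element vector load, while the block's 4
// waves spread along N (NBlkInTile = 4); repeats are 1 in both dims. These values are
// hand-picked for the current debug configuration.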
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] MakeBFlatDramTileDistribution():\n");
printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n",
NPerBlock,
KPerBlock,
NPerBlock * KPerBlock);
printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n",
BlockSize,
get_warp_size(),
Problem::VectorLoadSize);
printf("[PIPELN] NRepeat = %d, NBlkInTile = %d, NThdInBlk = %d, NLoad = %d\n", NRepeat, NBlkInTile, NThdInBlk, NLoad);
printf("[PIPELN] KRepeat = %d, KBlkInTile = %d, KThdInBlk = %d, KLoad = %d\n", KRepeat, KBlkInTile, KThdInBlk, KLoad);
}
#endif
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NRepeat, NBlkInTile, NThdInBlk, NLoad>,
sequence<KRepeat, KBlkInTile, KThdInBlk, KLoad>>, // first dim
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<2, 2>,
sequence<0, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
......