Commit e076a320 authored by feifei14119

debug a

parent 839cf897
......@@ -5,5 +5,5 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -ggdb -g -O0 -v -save-temps)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1 -DDEBUG_CNT=64)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
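These two flags travel together: FEIFEI_DEBUG gates all of the instrumentation in this commit, and DEBUG_CNT sets how many debug slots each thread owns. A minimal sketch of the pattern the kernel-side code below follows (names taken from this diff; gid is the flattened global thread id computed there):

#if FEIFEI_DEBUG
// every thread owns slots [gid * DEBUG_CNT, (gid + 1) * DEBUG_CNT)
for(int i = 0; i < DEBUG_CNT; i++)
    dbg_int[gid * DEBUG_CNT + i] = -1; // sentinel until a real value is dumped
#endif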
......@@ -183,9 +183,9 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
#if FEIFEI_DEBUG
ck_tile::HostTensor<int> dbg_int({M * N * 64});
ck_tile::HostTensor<float> dbg_fp32({M * N * 64});
ck_tile::HostTensor<ADataType> dbg_f168({M * N * 64});
ck_tile::HostTensor<int> dbg_int({M * N * DEBUG_CNT});
ck_tile::HostTensor<float> dbg_fp32({M * N * DEBUG_CNT});
ck_tile::HostTensor<ADataType> dbg_f168({M * N * DEBUG_CNT});
ck_tile::DeviceMem dbg_int_buf(dbg_int.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbg_fp32_buf(dbg_fp32.get_element_space_size_in_bytes());
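Host side, each tensor is sized M * N * DEBUG_CNT so every output position can dump up to DEBUG_CNT values. A hedged readback sketch for after the kernel completes, assuming ck_tile::DeviceMem provides the usual FromDevice copy helper (an assumption; check the host API):

// copy the device debug buffers back into the host tensors for inspection
dbg_int_buf.FromDevice(dbg_int.data());   // assumes DeviceMem::FromDevice(void*)
dbg_fp32_buf.FromDevice(dbg_fp32.data());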
......@@ -362,7 +362,7 @@ int run_flatmm_example_with_layouts(int argc,
int GridDimY = 1;
int BlockDimX = 64;
int BlockDimY = 4;
int DbgCnt = 64;
int DbgCnt = DEBUG_CNT;
int BlockSize = BlockDimX * BlockDimY;
// a_host
{
......
......@@ -3,40 +3,9 @@
#pragma once
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
// #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
// #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
......@@ -45,10 +14,15 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
// block
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
// pipeline
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp"
// kernel
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_ = BlockFlatmmASmemBSmemCRegV1DefaultPolicy>
struct BlockFlatmmASmemBSmemCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
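// To make the iteration split concrete (a worked example, assuming the
// 32x512x128 block tile with 1x4x1 warps and the 16x16x32 warp GEMM named
// by the flatmm headers above; the chosen policy ultimately decides):
//   MIterPerWarp = kM / (MWarp * WG::kM) = 32  / (1 * 16) = 2
//   NIterPerWarp = kN / (NWarp * WG::kN) = 512 / (4 * 16) = 8
// i.e. each warp owns a 2x8 grid of 16x16 C sub-tiles.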
#if 1
// C += A * B
// template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
template <typename ABlockWindow>
CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window
#if FEIFEI_DEBUG
,
const BDataType* b_ptr,
int* dbg_int,
float* dbg_fp32,
void* dbg_f168
#endif
) const
{
#if FEIFEI_DEBUG
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[BLOCK ] BlockFlatmmASmemBSmemCRegV1():\n");
}
uint32_t tidx = threadIdx.x;
uint32_t tidy = threadIdx.y;
uint32_t bidx = blockIdx.x;
uint32_t bidy = blockIdx.y;
uint32_t bdmx = blockDim.x;
uint32_t bdmy = blockDim.y;
uint32_t gdmx = gridDim.x;
uint32_t gdmy = gridDim.y;
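// flatten (block, thread) coordinates into one global linear thread id:
// gid = bidy * (gridDim.x * blockSize) + bidx * blockSize + tidy * blockDim.x + tidx,
// where blockSize = blockDim.x * blockDim.y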
uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
half_t* dbg_f16 = static_cast<half_t*>(dbg_f168);
for(int i = 0; i < DEBUG_CNT; i++)
{
dbg_int[gid * DEBUG_CNT + i] = -1;
dbg_fp32[gid * DEBUG_CNT + i] = -1.0f;
dbg_f16[gid * DEBUG_CNT + i] = ck_tile::type_convert<ck_tile::half_t>(-1.0f);
}
#endif
/*
static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
*/
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
/*
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
*/
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
// constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
// constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Warp loop in block:
constexpr index_t kIter = 0;
constexpr index_t mIter = 0;
const auto a_warp_tensor = load_tile(a_warp_windows(number<mIter>{})(number<kIter>{}));
#if 1
// feifei TODO: Implement gemm here
#else
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
#endif
}
#else
// C += A * B
template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockWindow& a_block_window,
const BBlockWindow& b_block_window) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindow>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindow& b_block_window) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
return c_block_tensor;
}
#endif
};
} // namespace ck_tile
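Taken together, the finished (non-debug) contract from the #else branch is: the three-argument operator() accumulates into an existing C tile, and the two-argument overload allocates C first via MakeCBlockTile(). A usage sketch, assuming a suitable Problem type and valid A/B windows are already in scope:

BlockFlatmmASmemBSmemCRegV1<Problem> flatmm{};
auto c = flatmm(a_block_tensor_tmp, b_smem_window); // C  = A * B, C allocated internally
flatmm(c, a_smem_window, b_smem_window);            // C += A * B, accumulate into c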
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Custom policy for BlockFlatmmASmemBSmemCRegV1
// Unlike the default policy, a custom policy is templated on the data types, the warp layout and the warp GEMM
template <typename AType_,
typename BType_,
typename CType_,
typename BlockWarps_,
typename WarpGemm_>
struct BlockFlatmmASmemBSmemCRegV1CustomPolicy
{
using AType = remove_cvref_t<AType_>;
using BType = remove_cvref_t<BType_>;
using CType = remove_cvref_t<CType_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<WarpGemm_>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockFlatmmASmemBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockFlatmmASmemBSmemCRegV1DefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 && kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};
} // namespace ck_tile
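Callers unpack the policy's result positionally, exactly as MakeCBlockTile does above; the contract in brief (values shown are the f16 path's, per the TransposedCDistribution branch):

// constexpr auto config    = Policy::template GetWarpGemmMWarpNWarp<Problem>();
// using WG                 = remove_cvref_t<decltype(config.template at<0>())>; // warp GEMM op
// constexpr index_t MWarp  = config.template at<1>(); // 4 on the f16 path
// constexpr index_t NWarp  = config.template at<2>(); // 1 on the f16 path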
......@@ -85,7 +85,7 @@ struct FlatmmKernel
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return TilePartitioner::GridSize(M, N);
return TilePartitioner::GridSize(M, N); // feifei TODO: split K here
// return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
......@@ -178,7 +178,7 @@ struct FlatmmKernel
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
index_t splitted_k; // problem K after the split-K division
};
CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs)
......@@ -473,7 +473,7 @@ struct FlatmmKernel
const BDataType* b_ptr,
int* dbg_int,
float* dbg_fp32,
short* dbg_f168
void* dbg_f168
#endif
)
{
......@@ -481,12 +481,55 @@ struct FlatmmKernel
// Create Flatmm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset);
// origin layout
// const auto& gemm_tensor_views_tuple =
// MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
// Debug origin layout
// const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
// a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const auto& gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
#if FEIFEI_DEBUG
////////////////////////////////////////////////////////
const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view
const auto& a_gemm_tensor_desc = a_gemm_tensor_views.desc_; // tensor_descriptor
const auto& a_gemm_buff_views = a_gemm_tensor_views.buf_; // buffer_view
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[KERNEL] a_gemm_tensor_view: size = %ld, len = [%d, %d], top = [%d, %d], upper = %d, lower = %d\n",
a_gemm_tensor_desc.get_element_space_size(),
a_gemm_tensor_desc.get_length(I0), a_gemm_tensor_desc.get_length(I1),
a_gemm_tensor_desc.get_top_dimension_hidden_ids()[0], a_gemm_tensor_desc.get_top_dimension_hidden_ids()[1],
a_gemm_tensor_desc.get_upper_dimension_hidden_idss()(I0)[0],
a_gemm_tensor_desc.get_lower_dimension_hidden_idss()(I0)[0]
);
}
const auto& a_pad_tensor_views = gemm_pad_views.at(I0); // tensor_view
const auto& a_pad_tensor_desc = a_pad_tensor_views.desc_; // tensor_descriptor
const auto& a_pad_buff_views = a_pad_tensor_views.buf_; // buffer_view
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[KERNEL] a_pad_tensor_view: size = %ld, len = [%d, %d], top = [%d, %d], upper = %d, lower = %d\n",
a_pad_tensor_desc.get_element_space_size(),
a_pad_tensor_desc.get_length(I0), a_pad_tensor_desc.get_length(I1),
a_pad_tensor_desc.get_top_dimension_hidden_ids()[0], a_pad_tensor_desc.get_top_dimension_hidden_ids()[1],
a_pad_tensor_desc.get_upper_dimension_hidden_idss()(I0)[0],
a_pad_tensor_desc.get_lower_dimension_hidden_idss()(I0)[0]
);
}
const auto& a_tile_win = gemm_tile_windows.at(I0); // tile_window_with_static_lengths
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[KERNEL] a_gemm_tile_window: dim_num = %d\n",
a_tile_win.get_num_of_dimension()
);
}
////////////////////////////////////////////////////////
#endif
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
......@@ -555,11 +598,14 @@ struct FlatmmKernel
int* dbg_int = static_cast<int*>(kargs.dbg_int_ptr);
float* dbg_fp32 = static_cast<float*>(kargs.dbg_fp32_ptr);
short* dbg_f168 = static_cast<short*>(kargs.dbg_f168_ptr);
half_t* dbg_f16 = static_cast<half_t*>(kargs.dbg_f168_ptr);
dbg_int[gid] = 1;
dbg_fp32[gid] = 1.0f;
dbg_f168[gid] = ck_tile::type_convert<ck_tile::half_t>(1.0f);
for(int i = 0; i < DEBUG_CNT; i++)
{
dbg_int[gid * DEBUG_CNT + i] = 0;
dbg_fp32[gid * DEBUG_CNT + i] = 0.0f;
dbg_f16[gid * DEBUG_CNT + i] = ck_tile::type_convert<ck_tile::half_t>(0.0f);
}
#endif
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
......@@ -592,7 +638,7 @@ struct FlatmmKernel
b_ptr,
dbg_int,
dbg_fp32,
dbg_f168
kargs.dbg_f168_ptr
#endif
);
}
......@@ -611,7 +657,7 @@ struct FlatmmKernel
b_ptr,
dbg_int,
dbg_fp32,
dbg_f168
kargs.dbg_f168_ptr
#endif
);
}
......
......@@ -4,14 +4,14 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = FlatmmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy> // feifei TODO: add default policy
struct FlatmmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
......@@ -23,7 +23,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
......@@ -41,21 +42,24 @@ struct FlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
return integer_divide_ceil(sizeof(ADataType) *
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>()
.get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
sizeof(BDataType) * PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>()
.get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
return PipelinePolicy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return PipelinePolicy::IsTransposeC();
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
......@@ -72,7 +76,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
const BDataType* b_ptr,
int* dbg_int,
float* dbg_fp32,
short* dbg_f168
void* dbg_f168
#endif
) const
{
......@@ -80,6 +84,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] FlatmmPipelinen():\n");
printf("[PIPELN] num_loop = %d\n", num_loop);
}
uint32_t tidx = threadIdx.x;
......@@ -92,9 +97,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1
uint32_t gdmy = gridDim.y;
uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx;
dbg_int[gid] = -1;
dbg_fp32[gid] = -1.0f;
dbg_f168[gid] = ck_tile::type_convert<ck_tile::half_t>(-1.0f);
half_t* dbg_f16 = static_cast<half_t*>(dbg_f168);
for(int i = 0; i < DEBUG_CNT; i++)
{
dbg_int[gid * DEBUG_CNT + i] = 1;
dbg_fp32[gid * DEBUG_CNT + i] = 1.0f;
dbg_f16[gid * DEBUG_CNT + i] = ck_tile::type_convert<ck_tile::half_t>(1.0f);
}
#endif
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
......@@ -108,12 +117,93 @@ struct FlatmmPipelineAGmemBGmemCRegV1
#if 1
// feifei TODO: Implement gemm here
// Get block flatmm
auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Prefetch -----------------------------------------------------------
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
#if FEIFEI_DEBUG // debug A global load
int a_dim = a_block_tile.get_num_of_dimension();
int a_sz = a_block_tile.get_thread_buffer_size();
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
{
printf("[PIPELN] a_dim = %d, a_sz = %d\n", a_dim, a_sz);
}
for(auto i = 0; i < a_sz; i++)
{
dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i];
}
return nullptr;
#endif
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// Loop ---------------------------------------------------------------
// Do flatmm
block_flatmm(a_lds_gemm_window
#if FEIFEI_DEBUG
,
b_ptr,
dbg_int,
dbg_fp32,
dbg_f168
#endif
);
// Tail ---------------------------------------------------------------
return nullptr;
#else
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
......@@ -125,7 +215,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
constexpr auto b_lds_block_desc =
PipelinePolicy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
......@@ -134,7 +225,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
......@@ -145,7 +236,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
PipelinePolicy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
......@@ -184,7 +275,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
PipelinePolicy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
......@@ -198,7 +289,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
......@@ -235,7 +326,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
PipelinePolicy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
......@@ -271,7 +362,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
const BDataType* b_ptr,
int* dbg_int,
float* dbg_fp32,
short* dbg_f168
void* dbg_f168
#endif
) const
{
......
......@@ -446,7 +446,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
......@@ -458,12 +458,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
using BlockFlatmmPolicy =
BlockFlatmmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
};
......