Commit 6c270303 authored by dummycoderfe's avatar dummycoderfe
Browse files

Change GEMM pipelines to v4; compiles OK.

parent c808fa65
......@@ -6,8 +6,10 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV2
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// M->N Warp
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{
// M->N Warp
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
return a_block_dstr;
// return make_static_distributed_tensor<ADataType>(a_block_dstr);
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistribution()
{
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
return b_block_dstr;
// return make_static_distributed_tensor<BDataType>(b_block_dstr);
}
// Prefetch lds
template <typename BlockWindowTmp, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindowTmp& block_window, BlockTensor& block_tensor)
{
auto tileDist = BlockTensor::get_tile_distribution();//.get_static_tile_distribution_encoding()
return load_tile(block_tensor, make_tile_window(block_window, tileDist));
}
// C = A * B
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockGemmARegBRegCRegV2
// Default policy class should not be templated, put template on member functions instead
// Default policy for BlockGemmARegBRegCRegV2.
// Default policy class should not be templated; put the template on member
// functions instead so one policy type serves all problems.
struct BlockGemmARegBRegCRegV2DefaultPolicy
{
    // Returns {warp-gemm instance, MWarp, NWarp} for the problem's data types.
    // Fix: the original if-constexpr chain had no fallback, so an unsupported
    // type combination made the function deduce `void` and produced a confusing
    // error at the call site; fail loudly here instead.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16{}, 2, 2);
        }
        else
        {
            // dependent-false: fires only when this branch is instantiated
            static_assert(sizeof(typename Problem::ADataType) == 0,
                          "BlockGemmARegBRegCRegV2DefaultPolicy: unsupported "
                          "A/B/C data type combination");
        }
    }
};
} // namespace ck_tile
......@@ -148,36 +148,6 @@ struct BlockGemmASmemBSmemCRegV1
});
});
});
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// // hot loop:
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// // read A warp tensor from A block window
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// // read B warp tensor from B Block window
// // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// // read C warp tensor from C block tensor
// CWarpTensor c_warp_tensor;
// c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// // warp GEMM
// WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter));
// // write C warp tensor into C block tensor
// c_block_tensor.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
// c_warp_tensor.get_thread_buffer());
// });
// });
// });
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
......
......@@ -39,20 +39,38 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return integer_least_multiple(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 2 +
integer_least_multiple(
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 2;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, kKPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
......@@ -75,23 +93,23 @@ struct GemmPipelineAGmemBGmemCRegV1
"wrong!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
ADataType* p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned);
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
......@@ -101,8 +119,10 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B DRAM tile window for load
auto b_copy_dram_window =
......@@ -112,143 +132,144 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_store_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_store_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_load_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_load_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_load_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_load_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
auto c_block_tile = Policy::template BlockGemm<Problem>::MakeCBlockTile();
// a b register tile
auto a_prefetch_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto a_prefetch_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto b_prefetch_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
auto b_prefetch_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
block_sync_lds();
// global read 1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// local prefetch 0
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0);
// LDS write 1
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
// global read 2
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = num_loop - 1;
while(iCounter > 2)
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// ping
{
block_sync_lds();
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
__builtin_amdgcn_sched_barrier(0);
// pong
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0);
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1);
}
iCounter -= 2;
}
// LDS write 0
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
//tail 3
if (iCounter == 1) {
// 3
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
__builtin_amdgcn_sched_barrier(0);
}
else
// 2
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0);
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1);
__builtin_amdgcn_sched_barrier(0);
}
}
// __syncthreads();
// if (threadIdx.x == 0) {
// for (int j = 0; j < 256; j++) {
// for(int i = 0; i < 32; i++) {
// int ik0 = i /8;
// int ik1 = i % 8;
// printf("%f,", type_convert<float>(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256]));
// }
// printf("\n");
// }
// }
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
//1
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
}
else
//tail 2
} else {
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
__builtin_amdgcn_sched_barrier(0);
}
// 2
{
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1);
}
iCounter--;
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
return c_block_tile;
}
......@@ -268,4 +289,189 @@ struct GemmPipelineAGmemBGmemCRegV1
}
};
// __device__ static constexpr auto HotLoopScheduler()
// {
// // schedule
// constexpr auto num_ds_read_inst =
// HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
// constexpr auto num_ds_write_inst =
// HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
// ;
// constexpr auto num_buffer_load_inst =
// HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
// ;
// constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
// constexpr auto num_issue = num_buffer_load_inst;
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
// });
// }
// CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
// {
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});
// constexpr index_t A_LDS_Read_Width = KPerXDL;
// constexpr index_t B_LDS_Read_Width = KPerXDL;
// constexpr index_t A_Buffer_Load_Inst_Num =
// MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
// constexpr index_t B_Buffer_Load_Inst_Num =
// NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
// constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL);
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
// constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
// constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
// constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
// constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
// constexpr auto ds_read_a_issue_cycle =
// A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// constexpr auto ds_read_b_issue_cycle =
// B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// constexpr auto ds_read_a_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
// constexpr auto ds_read_b_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// constexpr auto num_dsread_b_mfma =
// (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// // stage 1
// // Separate this part?
// // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// // sizeof(ComputeDataType) /
// // sizeof(BDataType)
// // ? sizeof(ComputeDataType) /
// // sizeof(ADataType) : sizeof(ComputeDataType)
// // / sizeof(BDataType);
// constexpr auto num_mfma_stage1 =
// num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
// constexpr auto num_mfma_per_issue =
// num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
// });
// static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
// });
// // stage 2
// static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
// ds_read_a_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
// ds_read_b_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
} // namespace ck_tile
......@@ -11,6 +11,9 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using BlockGemmPolicy = BlockGemmARegBRegCRegV2DefaultPolicy;
template <typename Problem>
using BlockGemm = BlockGemmARegBRegCRegV2<Problem, BlockGemmPolicy>;
#if 0
// 2d
......@@ -472,9 +475,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
return BlockGemm<Problem>{};
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment