Commit 928b6d1a authored by coderfeli's avatar coderfeli
Browse files

split smem into 2 arrays, but still the same

parent c275904b
......@@ -41,7 +41,7 @@ CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 32768; }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{
......@@ -87,7 +87,7 @@ struct CShuffleEpilogueV2
// static constexpr bool kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536;}//kMPerBlock * kNPerBlock * sizeof(ODataType); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kMPerBlock * kNPerBlock * sizeof(ODataType); }
// TODO: this function assumes the store-out vector size is the same as the OAccTile last dimension size.
// How do we fix this?
......
......@@ -161,13 +161,14 @@ struct GemmKernel
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr_1[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by the whole workgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
......
......@@ -165,7 +165,8 @@ struct GemmPipelineAGmemBGmemCRegV1
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
......@@ -209,26 +210,20 @@ struct GemmPipelineAGmemBGmemCRegV1
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
// A tile in LDS view
const ADataType*__restrict__ p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
const ADataType*__restrict__ p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
const ADataType*__restrict__ p_a_lds2 = reinterpret_cast<ADataType*>(p_smem);
const ADataType*__restrict__ p_a_lds3 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
const ADataType*__restrict__ p_a_lds0 = reinterpret_cast<ADataType*>(p_smem_0);
const ADataType*__restrict__ p_a_lds1 = reinterpret_cast<ADataType*>(p_smem_1);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
auto a_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds2, a_lds_block_desc);
auto a_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds3, a_lds_block_desc);
auto a_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto a_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
// B tile in LDS view
const BDataType*__restrict__ p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
const BDataType*__restrict__ p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
const BDataType*__restrict__ p_b_lds2 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
const BDataType*__restrict__ p_b_lds3 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
const BDataType*__restrict__ p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_0) + a_lds_block_space_size_aligned);
const BDataType*__restrict__ p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_1) + a_lds_block_space_size_aligned);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
auto b_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds2, b_lds_block_desc);
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds3, b_lds_block_desc);
auto b_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A LDS tile window for store
auto a_lds_window0 = make_tile_window_linear(
......@@ -392,7 +387,8 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
void*__restrict__ p_smem_0,
void*__restrict__ p_smem_1) const
{
return operator()(
a_dram_block_window_tmp,
......@@ -400,7 +396,8 @@ struct GemmPipelineAGmemBGmemCRegV1
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem);
p_smem_0,
p_smem_1);
}
};
......
......@@ -102,8 +102,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16)
* 2;
MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16);
return smem_size_a;
}
......@@ -111,8 +110,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16)
* 2;
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16);
return smem_size_b;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment