Commit 928b6d1a authored by coderfeli's avatar coderfeli
Browse files

split smem to 2array, but still same

parent c275904b
...@@ -41,7 +41,7 @@ CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor() ...@@ -41,7 +41,7 @@ CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
} }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 32768; }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeODramTileDistribution()
{ {
...@@ -87,7 +87,7 @@ struct CShuffleEpilogueV2 ...@@ -87,7 +87,7 @@ struct CShuffleEpilogueV2
// static constexpr bool kMPerBlock = 64; // static constexpr bool kMPerBlock = 64;
static constexpr index_t kNPerBlock = Problem::kNPerBlock; static constexpr index_t kNPerBlock = Problem::kNPerBlock;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 65536;}//kMPerBlock * kNPerBlock * sizeof(ODataType); } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kMPerBlock * kNPerBlock * sizeof(ODataType); }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size // TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ? // how do we fix this ?
...@@ -104,7 +104,7 @@ struct CShuffleEpilogueV2 ...@@ -104,7 +104,7 @@ struct CShuffleEpilogueV2
block_sync_lds(); block_sync_lds();
auto o_dram_distri = MakeODramTileDistribution<Problem>(); auto o_dram_distri = MakeODramTileDistribution<Problem>();
auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri)); auto o_dram_tile = load_tile(make_tile_window(o_lds_window0, o_dram_distri));
store_tile(o_dram_window_tmp, o_dram_tile); store_tile(o_dram_window_tmp, o_dram_tile);
block_sync_lds(); block_sync_lds();
} }
}; };
......
...@@ -161,13 +161,14 @@ struct GemmKernel ...@@ -161,13 +161,14 @@ struct GemmKernel
{i_n, 0}); {i_n, 0});
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr_1[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by whole wokrgroup. // Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile = auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr); CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() { auto c_tensor_view = [&]() {
......
...@@ -165,7 +165,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -165,7 +165,8 @@ struct GemmPipelineAGmemBGmemCRegV1
const BDramBlockWindowTmp& b_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func, const BElementFunction& b_element_func,
index_t num_loop, index_t num_loop,
void* p_smem) const void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{ {
static_assert( static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> && std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
...@@ -209,26 +210,20 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -209,26 +210,20 @@ struct GemmPipelineAGmemBGmemCRegV1
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>(); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t a_lds_block_space_size_aligned = constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
// A tile in LDS view // A tile in LDS view
const ADataType*__restrict__ p_a_lds0 = reinterpret_cast<ADataType*>(p_smem); const ADataType*__restrict__ p_a_lds0 = reinterpret_cast<ADataType*>(p_smem_0);
const ADataType*__restrict__ p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned); const ADataType*__restrict__ p_a_lds1 = reinterpret_cast<ADataType*>(p_smem_1);
const ADataType*__restrict__ p_a_lds2 = reinterpret_cast<ADataType*>(p_smem);
const ADataType*__restrict__ p_a_lds3 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc); auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc); auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
auto a_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds2, a_lds_block_desc); auto a_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto a_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds3, a_lds_block_desc); auto a_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
// B tile in LDS view // B tile in LDS view
const BDataType*__restrict__ p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2); const BDataType*__restrict__ p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_0) + a_lds_block_space_size_aligned);
const BDataType*__restrict__ p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned); const BDataType*__restrict__ p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem_1) + a_lds_block_space_size_aligned);
const BDataType*__restrict__ p_b_lds2 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
const BDataType*__restrict__ p_b_lds3 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2 + b_lds_block_space_size_aligned);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc); auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc); auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
auto b_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds2, b_lds_block_desc); auto b_lds_ld_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds3, b_lds_block_desc); auto b_lds_ld_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A LDS tile window for store // A LDS tile window for store
auto a_lds_window0 = make_tile_window_linear( auto a_lds_window0 = make_tile_window_linear(
...@@ -392,7 +387,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -392,7 +387,8 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop, index_t num_loop,
void* p_smem) const void*__restrict__ p_smem_0,
void*__restrict__ p_smem_1) const
{ {
return operator()( return operator()(
a_dram_block_window_tmp, a_dram_block_window_tmp,
...@@ -400,7 +396,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -400,7 +396,8 @@ struct GemmPipelineAGmemBGmemCRegV1
b_dram_block_window_tmp, b_dram_block_window_tmp,
[](const BDataType& b) { return b; }, [](const BDataType& b) { return b; },
num_loop, num_loop,
p_smem); p_smem_0,
p_smem_1);
} }
}; };
......
...@@ -102,8 +102,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -102,8 +102,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{ {
constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16) MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16);
* 2;
return smem_size_a; return smem_size_a;
} }
...@@ -111,8 +110,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -111,8 +110,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{ {
constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16) MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16);
* 2;
return smem_size_b; return smem_size_b;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment