Commit 2e414b7c authored by carlushuang

refactor length/index setting in gridwise gemm

parent b134b7d6
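
This commit renames the slice-extent helpers GetAMultiIndex / GetBMultiIndex / GetCMultiIndex to GetASliceLength / GetBSliceLength / GetCSliceLength, adds layout-aware coordinate helpers GetAIndex / GetBIndex / GetCIndex, and rewrites the gridwise-GEMM call sites to build slice origins and K/N move steps through these helpers instead of hand-rolled ck::make_multi_index expressions. It also reorders the C-tensor RunRead arguments to source-first and fixes the assignment direction in the BypassTransfer path of the MatC store specialization.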
@@ -128,7 +128,7 @@ struct GridwiseGemmAvx2_MxN
         return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
     }
 
-    static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
+    static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
     {
         if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                   ck::tensor_layout::gemm::RowMajor>::value)
@@ -146,7 +146,7 @@ struct GridwiseGemmAvx2_MxN
         }
     }
 
-    static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
+    static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
     {
         // n_per_blk should be 8x
         if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
@@ -168,11 +168,49 @@ struct GridwiseGemmAvx2_MxN
         }
     }
 
-    static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+    static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
     {
         return ck::make_multi_index(m_per_blk, n_per_blk);
     }
 
+    static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // A : M, K
+            return ck::make_multi_index(i_m, i_k);
+        }
+        else
+        {
+            // A : K, M
+            return ck::make_multi_index(i_k, i_m);
+        }
+    }
+
+    static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
+    {
+        // i_n should be 8x
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // B : K, N
+            return ck::make_multi_index(i_k, i_n);
+        }
+        else
+        {
+            // B : N/8, K, N8
+            return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
+                                        i_k,
+                                        i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+        }
+    }
+
+    static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
+    {
+        return ck::make_multi_index(i_m, i_n);
+    }
+
     static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                         const BGridDesc& b_grid_desc,
                                         const CGridDesc& c_grid_desc)
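
GetBIndex carries the only non-trivial math here: for the packed (non-RowMajor) B layout stored as (N/8, K, N8), a logical (k, n) coordinate splits n into a vector-block index and an intra-vector lane. Below is a minimal standalone sketch of that split, plus a check that it reproduces the hand-rolled multi-indices removed further down; kBVec is a stand-in for ThreadwiseGemm_Dispatch::MatrixBMinVectorSize, and the 8-alignment assumption mirrors the "i_n should be 8x" comment above.

```cpp
// Illustrative sketch, not CK code: the packed-B coordinate split that
// GetBIndex centralizes, checked against the old hand-written indices.
#include <array>
#include <cassert>

using index_t = int;
constexpr index_t kBVec = 8; // stand-in for MatrixBMinVectorSize (AVX2 float width)

std::array<index_t, 3> get_b_index(index_t i_k, index_t i_n)
{
    // B stored as (N/8, K, N8): vector-block index, K, lane within the vector
    return {i_n / kBVec, i_k, i_n % kBVec};
}

int main()
{
    // K move step: old call sites used make_multi_index(0, k_per_block, 0).
    const index_t k_per_block = 64;
    assert((get_b_index(k_per_block, 0) == std::array<index_t, 3>{0, k_per_block, 0}));

    // N slice origin: old call sites used
    // make_multi_index(integer_divide_ceil(i_nc, 8), 0, 0); for 8-aligned i_nc
    // the floor division i_nc / 8 agrees with the ceiling, and the lane is 0.
    const index_t i_nc = 24;
    assert((get_b_index(0, i_nc) == std::array<index_t, 3>{i_nc / kBVec, 0, 0}));
    return 0;
}
```

The alignment assumption is exactly what the surrounding code enforces by rounding nc_size up with integer_least_multiple, so replacing the ceil-based hand-rolled indices with GetBIndex should be behavior-preserving.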
@@ -260,8 +298,8 @@ struct GridwiseGemmAvx2_MxN
         //
         if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
         {
-            auto a_move_k_step = ck::make_multi_index(0, k_per_block);
-            auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
+            auto a_move_k_step = GetAIndex(0, k_per_block);
+            auto b_move_k_step = GetBIndex(k_per_block, 0);
 
             const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
             const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
@@ -332,31 +370,19 @@ struct GridwiseGemmAvx2_MxN
                 nc_size = math::integer_least_multiple(
                     nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
 
-                a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
-                b_threadwise_copy.SetSrcSliceOrigin(
-                    b_grid_desc,
-                    ck::make_multi_index(
-                        math::integer_divide_ceil(
-                            i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-                        0,
-                        0));
+                a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
+                b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));
 
                 auto c_block_desc =
                     UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
 
-                if constexpr(UseCLocalBuffer)
-                {
-                    // c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
-                    // ck::make_multi_index(i_mc, i_nc));
-                }
-                else
+                if constexpr(!UseCLocalBuffer)
                 {
-                    c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
-                                                        ck::make_multi_index(i_mc, i_nc));
-                    c_threadwise_copy.RunRead(c_block_desc,
-                                              c_block_buf,
-                                              c_grid_desc,
+                    c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
+                    c_threadwise_copy.RunRead(c_grid_desc,
                                               c_grid_buf,
-                                              GetCMultiIndex(mc_size, nc_size));
+                                              c_block_desc,
+                                              c_block_buf,
+                                              GetCSliceLength(mc_size, nc_size));
                 }
 
                 for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
@@ -370,12 +396,12 @@ struct GridwiseGemmAvx2_MxN
                                               a_grid_buf,
                                               a_block_desc,
                                               a_block_buf,
-                                              GetAMultiIndex(mc_size, kc_size));
+                                              GetASliceLength(mc_size, kc_size));
 
                     b_threadwise_copy.RunRead(b_grid_desc,
                                               b_grid_buf,
                                               b_block_desc,
                                               b_block_buf,
-                                              GetBMultiIndex(kc_size, nc_size));
+                                              GetBSliceLength(kc_size, nc_size));
 
                     blockwise_gemm.Run(a_block_desc,
                                        a_block_buf,
@@ -395,25 +421,19 @@ struct GridwiseGemmAvx2_MxN
                     }
                 }
 
-                // if constexpr(UseCLocalBuffer)
-                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
-                                                    ck::make_multi_index(i_mc, i_nc));
+                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, GetCIndex(i_mc, i_nc));
                 c_threadwise_copy.RunWrite(c_block_desc,
                                            c_block_buf,
                                            c_grid_desc,
                                            c_grid_buf,
-                                           GetCMultiIndex(mc_size, nc_size));
+                                           GetCSliceLength(mc_size, nc_size));
             }
         }
     }
         else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
         {
-            auto a_move_k_step = ck::make_multi_index(0, k_per_block);
-            auto b_move_k_step = ck::make_multi_index(
-                math::integer_divide_ceil(n_per_block,
-                                          ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-                0,
-                0);
+            auto a_move_k_step = GetAIndex(0, k_per_block);
+            auto b_move_k_step = GetBIndex(0, n_per_block);
 
             const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
             const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads);
@@ -472,7 +492,7 @@ struct GridwiseGemmAvx2_MxN
                 if(i_mc >= GemmM)
                     break;
                 ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
-                a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
+                a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
 
                 for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
                 {
                     ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
@@ -482,10 +502,9 @@ struct GridwiseGemmAvx2_MxN
                                               a_grid_buf,
                                               a_block_desc,
                                               a_block_buf,
-                                              GetAMultiIndex(mc_size, kc_size));
-                    b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
-                                                        ck::make_multi_index(0, i_kc, 0));
+                                              GetASliceLength(mc_size, kc_size));
+                    b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(i_kc, 0));
 
                     // TODO: if use local C buffer, then this nc loop need to loop only once
                     for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
@@ -500,7 +519,7 @@ struct GridwiseGemmAvx2_MxN
                                                   b_grid_buf,
                                                   b_block_desc,
                                                   b_block_buf,
-                                                  GetBMultiIndex(kc_size, nc_size));
+                                                  GetBSliceLength(kc_size, nc_size));
 
                         auto c_block_desc = UseCLocalBuffer
                                                 ? GetCBlockDescriptor(mc_size, nc_size)
@@ -508,13 +527,13 @@ struct GridwiseGemmAvx2_MxN
                         if constexpr(!UseCLocalBuffer)
                         {
-                            c_threadwise_copy.SetSrcSliceOrigin(
-                                c_block_desc, ck::make_multi_index(i_mc, i_nc));
-                            c_threadwise_copy.RunRead(c_block_desc,
-                                                      c_block_buf,
-                                                      c_grid_desc,
+                            c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
+                                                                GetCIndex(i_mc, i_nc));
+                            c_threadwise_copy.RunRead(c_grid_desc,
                                                       c_grid_buf,
-                                                      GetCMultiIndex(mc_size, nc_size));
+                                                      c_block_desc,
+                                                      c_block_buf,
+                                                      GetCSliceLength(mc_size, nc_size));
                         }
 
                         blockwise_gemm.Run(a_block_desc,
@@ -535,14 +554,14 @@ struct GridwiseGemmAvx2_MxN
                         if constexpr(UseCLocalBuffer)
                         {
-                            c_threadwise_copy.SetDstSliceOrigin(
-                                c_grid_desc, ck::make_multi_index(i_mc, i_nc));
+                            c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
+                                                                GetCIndex(i_mc, i_nc));
                             c_threadwise_copy.RunWrite(c_block_desc,
                                                        c_block_buf,
                                                        c_grid_desc,
                                                        c_grid_buf,
-                                                       GetCMultiIndex(mc_size, nc_size));
+                                                       GetCSliceLength(mc_size, nc_size));
                         }
                         else
                         {
@@ -550,14 +569,14 @@ struct GridwiseGemmAvx2_MxN
                             // elementwise op from global to global
                             if((i_kc + k_per_block) >= GemmK)
                             {
-                                c_threadwise_copy.SetDstSliceOrigin(
-                                    c_grid_desc, ck::make_multi_index(i_mc, i_nc));
+                                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
+                                                                    GetCIndex(i_mc, i_nc));
                                 c_threadwise_copy.RunWrite(c_block_desc,
                                                            c_block_buf,
                                                            c_grid_desc,
                                                            c_grid_buf,
-                                                           GetCMultiIndex(mc_size, nc_size));
+                                                           GetCSliceLength(mc_size, nc_size));
                             }
                         }
                     }
...
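
One change above goes beyond renaming: the two C-tensor RunRead calls swap their argument order from (c_block_desc, c_block_buf, c_grid_desc, c_grid_buf, ...) to (c_grid_desc, c_grid_buf, c_block_desc, c_block_buf, ...), so the source (grid) pair now comes first, matching RunWrite, which already lists its source (block) pair before the destination (grid) pair. A hypothetical mock of that source-first convention (the real CK transfer is a templated member function; the names here are illustrative):

```cpp
// Hypothetical mock, not the CK API: one generic copy whose source
// (descriptor, buffer) pair always precedes the destination pair -- the
// ordering both RunRead (grid -> block) and RunWrite (block -> grid) now use.
#include <cstddef>
#include <vector>

struct Desc { std::size_t offset; }; // stand-in for a tensor descriptor

template <typename T>
void run_copy(const Desc& src_desc, const std::vector<T>& src_buf,
              const Desc& dst_desc, std::vector<T>& dst_buf, std::size_t count)
{
    for(std::size_t i = 0; i < count; ++i)
        dst_buf[dst_desc.offset + i] = src_buf[src_desc.offset + i];
}
```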
@@ -985,7 +985,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
     {
         if constexpr(BypassTransfer)
         {
-            src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
+            dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
         }
     }
...
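
The second file's change is a one-line direction fix: with BypassTransfer, the store is skipped by pointer aliasing rather than copying, and the old code aliased in the wrong direction, pointing the source buffer at the (unwritten) destination. A minimal sketch of the assumed intent, with simplified types; in CK the buffers are templated and the reinterpret_cast is needed because p_data_ is not declared as float* there:

```cpp
// Minimal sketch of the assumed bypass semantics, not the CK code.
struct Buffer
{
    float* p_data_;
};

void bypass_store(const Buffer& src_buf, Buffer& dst_buf, long src_offset)
{
    // Fixed direction: the destination aliases the source data at src_offset,
    // so consumers of dst read the already-computed source values directly.
    dst_buf.p_data_ = src_buf.p_data_ + src_offset;
}
```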