Commit 2e414b7c authored by carlushuang

refactor length/index setting in gridwise gemm

parent b134b7d6
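This commit splits the old GetXMultiIndex helpers into two roles: GetASliceLength/GetBSliceLength/GetCSliceLength return the per-block copy extents, while the new GetAIndex/GetBIndex/GetCIndex helpers map a logical (row, column) coordinate into the layout-specific multi-index, so callers no longer hand-build ck::make_multi_index for the packed B layout. A minimal, self-contained sketch of the B mapping is below; get_b_index_packed and MinVec are illustrative stand-ins (not code from this commit), with MinVec mirroring ThreadwiseGemm_Dispatch::MatrixBMinVectorSize (8 on AVX2).

```cpp
// Illustrative sketch only: maps a logical (i_k, i_n) coordinate into the
// packed "N/8, K, N8" B layout used in the diff below.
#include <array>
#include <cassert>
#include <cstdint>

using index_t = std::int32_t;
constexpr index_t MinVec = 8; // stand-in for MatrixBMinVectorSize

std::array<index_t, 3> get_b_index_packed(index_t i_k, index_t i_n)
{
    return {i_n / MinVec, i_k, i_n % MinVec}; // B : N/8, K, N8
}

int main()
{
    // Column 19 of K-row 3 lives in vector block 2, lane 3.
    auto idx = get_b_index_packed(3, 19);
    assert(idx[0] == 2 && idx[1] == 3 && idx[2] == 3);
    return 0;
}
```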
@@ -128,7 +128,7 @@ struct GridwiseGemmAvx2_MxN
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
}
static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
@@ -146,7 +146,7 @@ struct GridwiseGemmAvx2_MxN
}
}
static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
@@ -168,11 +168,49 @@ struct GridwiseGemmAvx2_MxN
}
}
static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
return ck::make_multi_index(m_per_blk, n_per_blk);
}
static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(i_m, i_k);
}
else
{
// A : K, M
return ck::make_multi_index(i_k, i_m);
}
}
static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
{
// i_n should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(i_k, i_n);
}
else
{
// B : N/8, K, N8
return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
i_k,
i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
}
}
static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
{
return ck::make_multi_index(i_m, i_n);
}
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
@@ -260,8 +298,8 @@ struct GridwiseGemmAvx2_MxN
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(k_per_block, 0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
@@ -332,31 +370,19 @@ struct GridwiseGemmAvx2_MxN
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
b_threadwise_copy.SetSrcSliceOrigin(
b_grid_desc,
ck::make_multi_index(
math::integer_divide_ceil(
i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0));
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));
auto c_block_desc =
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(UseCLocalBuffer)
{
// c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
// ck::make_multi_index(i_mc, i_nc));
}
else
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.RunRead(c_block_desc,
c_block_buf,
c_grid_desc,
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunRead(c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
c_block_desc,
c_block_buf,
GetCSliceLength(mc_size, nc_size));
}
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
@@ -370,12 +396,12 @@ struct GridwiseGemmAvx2_MxN
a_grid_buf,
a_block_desc,
a_block_buf,
GetAMultiIndex(mc_size, kc_size));
GetASliceLength(mc_size, kc_size));
b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf,
b_block_desc,
b_block_buf,
GetBMultiIndex(kc_size, nc_size));
GetBSliceLength(kc_size, nc_size));
blockwise_gemm.Run(a_block_desc,
a_block_buf,
@@ -395,25 +421,19 @@ struct GridwiseGemmAvx2_MxN
}
}
// if constexpr(UseCLocalBuffer)
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
GetCSliceLength(mc_size, nc_size));
}
}
}
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(
math::integer_divide_ceil(n_per_block,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0);
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(0, n_per_block);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads);
@@ -472,7 +492,7 @@ struct GridwiseGemmAvx2_MxN
if(i_mc >= GemmM)
break;
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
@@ -482,10 +502,9 @@ struct GridwiseGemmAvx2_MxN
a_grid_buf,
a_block_desc,
a_block_buf,
GetAMultiIndex(mc_size, kc_size));
GetASliceLength(mc_size, kc_size));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
ck::make_multi_index(0, i_kc, 0));
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(i_kc, 0));
// TODO: if use local C buffer, then this nc loop need to loop only once
for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
@@ -500,7 +519,7 @@ struct GridwiseGemmAvx2_MxN
b_grid_buf,
b_block_desc,
b_block_buf,
GetBMultiIndex(kc_size, nc_size));
GetBSliceLength(kc_size, nc_size));
auto c_block_desc = UseCLocalBuffer
? GetCBlockDescriptor(mc_size, nc_size)
@@ -508,13 +527,13 @@ struct GridwiseGemmAvx2_MxN
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(
c_block_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.RunRead(c_block_desc,
c_block_buf,
c_grid_desc,
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunRead(c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
c_block_desc,
c_block_buf,
GetCSliceLength(mc_size, nc_size));
}
blockwise_gemm.Run(a_block_desc,
@@ -535,14 +554,14 @@ struct GridwiseGemmAvx2_MxN
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
GetCSliceLength(mc_size, nc_size));
}
else
{
@@ -550,14 +569,14 @@ struct GridwiseGemmAvx2_MxN
// elementwise op from global to global
if((i_kc + k_per_block) >= GemmK)
{
c_threadwise_copy.SetDstSliceOrigin(
c_grid_desc, ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c_grid_desc,
c_grid_buf,
GetCMultiIndex(mc_size, nc_size));
GetCSliceLength(mc_size, nc_size));
}
}
}
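A side effect of the refactor, visible in the a_move_k_step / b_move_k_step lines above, is that the k-advance steps are now produced by the same layout dispatch that builds the slice origins, e.g. GetAIndex(0, k_per_block) and GetBIndex(k_per_block, 0). The sketch below illustrates the idea for A; get_a_index is a hypothetical stand-in for GetAIndex, not the commit's real type.

```cpp
// Sketch only: for row-major A (M, K) the k-step is (0, k_per_block);
// for column-major A (K, M) it flips to (k_per_block, 0).
#include <array>
#include <cstdint>

using index_t = std::int32_t;

template <bool RowMajorA>
std::array<index_t, 2> get_a_index(index_t i_m, index_t i_k)
{
    if constexpr(RowMajorA)
        return {i_m, i_k}; // A : M, K
    else
        return {i_k, i_m}; // A : K, M
}

int main()
{
    constexpr index_t k_per_block = 64;
    auto row_major_step = get_a_index<true>(0, k_per_block);  // {0, 64}
    auto col_major_step = get_a_index<false>(0, k_per_block); // {64, 0}
    return (row_major_step[1] == 64 && col_major_step[0] == 64) ? 0 : 1;
}
```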
@@ -985,7 +985,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
if constexpr(BypassTransfer)
{
src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
}
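The last hunk flips the aliasing direction in the BypassTransfer path of the MatC store specialization: the destination buffer should borrow the source buffer's pointer (plus offset), rather than the source being overwritten with the destination's pointer. A minimal sketch of that idea, using simplified placeholder buffer types rather than the real ck dynamic buffers:

```cpp
// Sketch only: when the transfer is bypassed, dst is made to alias src's
// storage at the given offset, so the later write stage touches the same
// memory without an intermediate copy.
#include <cstddef>

struct BufferView
{
    float* p_data_ = nullptr;
};

void bypass_store(const BufferView& src_buf, BufferView& dst_buf, std::size_t src_offset)
{
    dst_buf.p_data_ = src_buf.p_data_ + src_offset; // dst now points into src
}

int main()
{
    float storage[16] = {};
    BufferView src{storage}, dst{};
    bypass_store(src, dst, 4);
    return dst.p_data_ == storage + 4 ? 0 : 1;
}
```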