"vscode:/vscode.git/clone" did not exist on "a715222c69da4147ca98eca327452ed5e8d45bcb"
Commit d6d37ea9 authored by carlushuang

Refactor Run to use slice length as block size. Fix a bug in general input copy.

parent 2e414b7c
@@ -123,9 +123,16 @@ struct GridwiseGemmAvx2_MxN
         }
     }

-    static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+    static auto GetCBlockDescriptor(const ck::index_t m_per_blk,
+                                    const ck::index_t n_per_blk,
+                                    const CGridDesc& c_grid_desc)
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+        if constexpr(UseCLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+        }
+        else
+            return c_grid_desc;
     }

     static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
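The refactored GetCBlockDescriptor folds the UseCLocalBuffer dispatch into the helper itself. Below is a minimal standalone sketch of that pattern, using simplified stand-in descriptor types rather than the real CK descriptors: inside a template, C++17 if constexpr may return different types from its two branches, because only the taken branch is instantiated for a given UseCLocalBuffer value.

#include <cstdio>

// Stand-ins for the CK descriptor types (assumptions, not the real API).
struct PackedDesc { int m, n; };   // plays the role of the packed local block descriptor
struct GridDesc   { int stride; }; // plays the role of CGridDesc

// Sketch of the GetCBlockDescriptor pattern: only the branch selected by
// UseCLocalBuffer is instantiated, so the two returns may have different types.
template <bool UseCLocalBuffer>
auto GetCBlockDescriptor(int m_per_blk, int n_per_blk, const GridDesc& c_grid_desc)
{
    if constexpr(UseCLocalBuffer)
        return PackedDesc{m_per_blk, n_per_blk}; // local buffer: packed m x n descriptor
    else
        return c_grid_desc;                      // no local buffer: use the grid descriptor directly
}

int main()
{
    GridDesc grid{1024};
    auto local  = GetCBlockDescriptor<true>(64, 128, grid);  // yields PackedDesc
    auto direct = GetCBlockDescriptor<false>(64, 128, grid); // yields GridDesc
    std::printf("%d x %d, stride %d\n", local.m, local.n, direct.stride);
}

Folding the choice into the helper is also what lets the later hunks collapse the runtime ternary `UseCLocalBuffer ? GetCBlockDescriptor(...) : c_grid_desc` into a single call.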
@@ -264,16 +271,16 @@ struct GridwiseGemmAvx2_MxN
             reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());

         auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
             FloatA,                                                  // FloatA,
             FloatB,                                                  // FloatB,
             FloatC,                                                  // FloatC,
             decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
             decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
-            decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
+            decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
             KPerBlock,               // KPerBlock,
             ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
             ThreadMNAccessOrder>{};  // ThreadMNAccessOrder // how we acces
                                      // gemm MN to utilize micro kernel>{};

         int total_threads = omp_get_max_threads();
@@ -325,7 +332,7 @@ struct GridwiseGemmAvx2_MxN
                             BElementwiseOperation{});

         auto c_threadwise_copy =
-            CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+            CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
                             ck::make_zero_multi_index<2>(),
                             c_grid_desc,
                             ck::make_zero_multi_index<2>(),
@@ -373,8 +380,7 @@ struct GridwiseGemmAvx2_MxN
         a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
         b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));

-        auto c_block_desc =
-            UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
+        auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
         if constexpr(!UseCLocalBuffer)
         {
             c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
@@ -456,7 +462,7 @@ struct GridwiseGemmAvx2_MxN
                                 BElementwiseOperation{});

             auto c_threadwise_copy =
-                CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+                CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
                                 ck::make_zero_multi_index<2>(),
                                 c_grid_desc,
                                 ck::make_zero_multi_index<2>(),
@@ -521,9 +527,7 @@ struct GridwiseGemmAvx2_MxN
                                       b_block_buf,
                                       GetBSliceLength(kc_size, nc_size));

-            auto c_block_desc = UseCLocalBuffer
-                                    ? GetCBlockDescriptor(mc_size, nc_size)
-                                    : c_grid_desc;
+            auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);

             if constexpr(!UseCLocalBuffer)
             {
...
@@ -368,15 +368,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
    {
        if constexpr(BypassTransfer)
        {
-            float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
-            dst_buf.p_data_ = p_src;
+            dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
        }
        else
        {
-            const ck::index_t m_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
-            const ck::index_t k_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+            const ck::index_t m_per_block = slice_length[Number<0>{}];
+            const ck::index_t k_per_block = slice_length[Number<1>{}];

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
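This hunk is the "slice length as block size" half of the commit message: the copy now reads its block extents directly from slice_length instead of digging them out of the destination descriptor's first transform. A minimal sketch under simplified assumptions (a plain std::array standing in for the compile-time indexed slice_length, and no real descriptor type):

#include <array>
#include <cstdio>

// Simplified stand-in for the slice lengths handed to the threadwise copy.
using SliceLengths = std::array<int, 2>;

// Sketch: block sizes come straight from the requested slice lengths,
// not from dst_desc.GetTransforms()[0].GetUpperLengths()[...].
void run_copy(const SliceLengths& slice_length)
{
    const int m_per_block = slice_length[0]; // was: upper length 0 of the first transform
    const int k_per_block = slice_length[1]; // was: upper length 1 of the first transform
    std::printf("copy a %d x %d tile\n", m_per_block, k_per_block);
}

int main() { run_copy({64, 128}); }

The descriptor route can only recover the full configured block size; reading the extents from slice_length means a partial edge tile (mc_size or nc_size smaller than the full block) is presumably honored automatically, which makes this both a cleanup and part of the fix.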
@@ -540,19 +537,23 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                    ck::index_t i_k_itr = k_per_block;
                    while(i_k_itr > 0)
                    {
-                        ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);
+                        ck::index_t current_k_block_along_c =
+                            ck::math::min(C - i_c_itr_k, i_k_itr);
+
+                        // printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
+                        //     current_k_block_along_c, i_c_itr_k, k_per_block); fflush(stdout);

                        if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                           (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
                            avx2_util::memcpy32_avx2(
-                                p_dst_k, p_src_k, current_k_block, element_op_);
+                                p_dst_k, p_src_k, current_k_block_along_c, element_op_);
                        else
-                            avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
+                            avx2_util::memset32_avx2(p_dst_k, 0, current_k_block_along_c);

-                        p_dst_k += current_k_block;
-                        p_src_k += current_k_block;
-                        i_c_itr_k += current_k_block;
+                        p_dst_k += current_k_block_along_c;
+                        p_src_k += current_k_block_along_c;
+                        i_c_itr_k += current_k_block_along_c;

                        if(i_c_itr_k >= C)
                        {
                            i_c_itr_k = 0;
@@ -569,7 +570,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                            p_src_k += input_offset_ovf_x_acc_y;
                        }

-                        i_k_itr -= current_k_block;
+                        i_k_itr -= current_k_block_along_c;
                    }

                    /*** go along Gemm K ***/
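These two hunks are the "bug in general input copy" from the commit message: the old per-chunk bound, min(C - i_c_itr_k, k_per_block), ignores how many elements are actually left in the slice, so once the walk wraps around C it can copy past the end of the K block. The fix bounds each chunk by the remaining count i_k_itr instead. A toy reproduction with assumed sizes (C = 10 input channels, a 16-element K block) that tallies what each clamp would copy:

#include <algorithm>
#include <cstdio>

int main()
{
    const int C = 10, k_per_block = 16; // assumed toy sizes
    int copied_old = 0, copied_new = 0;
    for(int i_k_itr = k_per_block, i_c = 0; i_k_itr > 0;)
    {
        copied_old += std::min(C - i_c, k_per_block); // old bound: ignores the remainder
        int chunk = std::min(C - i_c, i_k_itr);       // fixed bound: never exceeds it
        copied_new += chunk;
        i_c = (i_c + chunk >= C) ? 0 : i_c + chunk;   // wrap around C like the real loop
        i_k_itr -= chunk;
    }
    // Prints old=20 new=16: the old clamp would touch 20 elements of a 16-element block.
    std::printf("old=%d new=%d (block is %d)\n", copied_old, copied_new, k_per_block);
}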
@@ -765,11 +766,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
        }
        else
        {
-            const ck::index_t n_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
-            const ck::index_t k_per_block =
-                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+            const ck::index_t n_per_block = slice_length[Number<0>{}] * slice_length[Number<2>{}];
+            const ck::index_t k_per_block = slice_length[Number<1>{}];

            // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
            //     dst_desc.GetTransforms()[Number<0>{}]
@@ -1002,7 +1000,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
                if constexpr(!std::is_same<ElementwiseOperation,
                                           ck::tensor_operation::cpu::element_wise::PassThrough>::value)
                {
-                    // if (true) {
                    const ck::index_t m_per_block = slice_length[Number<0>{}];
                    const ck::index_t n_per_block = slice_length[Number<1>{}];
@@ -1073,11 +1070,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
                }
                else
                {
-                    const ck::index_t m_per_block =
-                        src_desc.GetTransforms()[Number<0>{}]
-                            .GetUpperLengths()[Number<0>{}]; // must be multiple of 8
-                    const ck::index_t n_per_block =
-                        src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+                    const ck::index_t m_per_block = slice_length[Number<0>{}];
+                    const ck::index_t n_per_block = slice_length[Number<1>{}];

                    const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
...