Commit d6d37ea9 authored by carlushuang

refactor Run to use slice length as block size. Fix a bug in general input copy

parent 2e414b7c
@@ -123,9 +123,16 @@ struct GridwiseGemmAvx2_MxN
 }
 }
-static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+static auto GetCBlockDescriptor(const ck::index_t m_per_blk,
+                                const ck::index_t n_per_blk,
+                                const CGridDesc& c_grid_desc)
 {
-return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+if constexpr(UseCLocalBuffer)
+{
+return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
+}
+else
+return c_grid_desc;
 }
 static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
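GetCBlockDescriptor now either builds a packed local block descriptor or hands back the grid descriptor itself, depending on the UseCLocalBuffer flag. Since the two descriptors generally have different types, the branch lives in an `if constexpr` inside a function with a deduced return type, so each instantiation returns its own type and every caller can simply write `auto c_block_desc = GetCBlockDescriptor(...)`. A minimal stand-alone sketch of that C++17 pattern, with std::array and std::vector as hypothetical stand-ins for the local-buffer and grid descriptors:

#include <array>
#include <cstddef>
#include <vector>

template <bool UseLocalBuffer>
auto make_c_block_storage(std::size_t n)
{
    if constexpr(UseLocalBuffer)
    {
        return std::array<float, 64>{}; // small fixed-size local buffer
    }
    else
    {
        return std::vector<float>(n); // "grid-like" storage of a different type
    }
}

// Each instantiation deduces its own return type:
//   auto local = make_c_block_storage<true>(4096);  // std::array<float, 64>
//   auto grid  = make_c_block_storage<false>(4096); // std::vector<float>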
@@ -264,16 +271,16 @@ struct GridwiseGemmAvx2_MxN
 reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
 auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
-FloatA, // FloatA,
-FloatB, // FloatB,
-FloatC, // FloatC,
-decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
-decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
-decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
-KPerBlock, // KPerBlock,
-ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
-ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
-// gemm MN to utilize micro kernel>{};
+FloatA, // FloatA,
+FloatB, // FloatB,
+FloatC, // FloatC,
+decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
+decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
+decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
+KPerBlock, // KPerBlock,
+ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
+ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
+// gemm MN to utilize micro kernel>{};
 int total_threads = omp_get_max_threads();
@@ -325,7 +332,7 @@ struct GridwiseGemmAvx2_MxN
 BElementwiseOperation{});
 auto c_threadwise_copy =
-CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
 ck::make_zero_multi_index<2>(),
 c_grid_desc,
 ck::make_zero_multi_index<2>(),
@@ -373,8 +380,7 @@ struct GridwiseGemmAvx2_MxN
 a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
 b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));
-auto c_block_desc =
-UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
+auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
 if constexpr(!UseCLocalBuffer)
 {
 c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
@@ -456,7 +462,7 @@ struct GridwiseGemmAvx2_MxN
 BElementwiseOperation{});
 auto c_threadwise_copy =
-CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
+CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
 ck::make_zero_multi_index<2>(),
 c_grid_desc,
 ck::make_zero_multi_index<2>(),
@@ -521,9 +527,7 @@ struct GridwiseGemmAvx2_MxN
 b_block_buf,
 GetBSliceLength(kc_size, nc_size));
-auto c_block_desc = UseCLocalBuffer
-? GetCBlockDescriptor(mc_size, nc_size)
-: c_grid_desc;
+auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
 if constexpr(!UseCLocalBuffer)
 {
@@ -368,15 +368,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
 {
 if constexpr(BypassTransfer)
 {
-float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
-dst_buf.p_data_ = p_src;
+dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
 }
 else
 {
-const ck::index_t m_per_block =
-dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
-const ck::index_t k_per_block =
-dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+const ck::index_t m_per_block = slice_length[Number<0>{}];
+const ck::index_t k_per_block = slice_length[Number<1>{}];
 const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
 float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
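In the non-bypass path the copy no longer re-derives its block size from the destination descriptor's transforms; it reads it from the slice length that Run passes in (the "use slice length as block size" part of the commit message). The slice length describes the tile actually being transferred on this call, e.g. a smaller edge tile built from mc_size/kc_size, whereas the descriptor lengths describe the full block, and for C the descriptor may now even be the raw grid descriptor. A rough sketch of the direction of the change, using plain stand-in types rather than the ck descriptor machinery:

#include <array>
#include <cstdio>

struct BlockDesc { std::array<int, 2> lengths; };  // stand-in for the ck block descriptor
using SliceLength = std::array<int, 2>;            // {m, k} of the tile for this call

void run_copy(const BlockDesc& dst_desc, const SliceLength& slice_length)
{
    // old: block size taken from the descriptor (always the full block size)
    // const int m_per_block = dst_desc.lengths[0];
    // new: block size taken from the per-call slice length (actual tile size)
    const int m_per_block = slice_length[0];
    const int k_per_block = slice_length[1];
    std::printf("copy a %d x %d tile (descriptor says %d x %d)\n",
                m_per_block, k_per_block, dst_desc.lengths[0], dst_desc.lengths[1]);
}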
@@ -540,19 +537,23 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
 ck::index_t i_k_itr = k_per_block;
 while(i_k_itr > 0)
 {
-ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);
+ck::index_t current_k_block_along_c =
+ck::math::min(C - i_c_itr_k, i_k_itr);
+// printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
+// current_k_block_along_c, i_c_itr_k,k_per_block); fflush(stdout);
 if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
 (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
 avx2_util::memcpy32_avx2(
-p_dst_k, p_src_k, current_k_block, element_op_);
+p_dst_k, p_src_k, current_k_block_along_c, element_op_);
 else
-avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
+avx2_util::memset32_avx2(p_dst_k, 0, current_k_block_along_c);
-p_dst_k += current_k_block;
-p_src_k += current_k_block;
+p_dst_k += current_k_block_along_c;
+p_src_k += current_k_block_along_c;
-i_c_itr_k += current_k_block;
+i_c_itr_k += current_k_block_along_c;
 if(i_c_itr_k >= C)
 {
 i_c_itr_k = 0;
@@ -569,7 +570,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
 p_src_k += input_offset_ovf_x_acc_y;
 }
-i_k_itr -= current_k_block;
+i_k_itr -= current_k_block_along_c;
 }
 /*** go along Gemm K ***/
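The "bug in general input copy" from the commit message appears to be the clamping above: the chunk copied along C was limited by the distance to the end of the channel dimension and by k_per_block, but not by the K work still remaining in this slice (i_k_itr), so a slice whose tail ends mid-channel could copy, zero-fill, and advance the pointers past the requested tile. Clamping with i_k_itr (and renaming the variable to current_k_block_along_c) fixes that. A simplified, self-contained sketch of the corrected loop, ignoring the padding/zero-fill branch and the H/W pointer adjustments of the real code:

#include <algorithm>

void copy_along_k(const float* src, float* dst, int C, int k_per_block)
{
    int i_c_itr_k = 0;           // current position within the C dimension
    int i_k_itr   = k_per_block; // elements still to copy along GEMM K
    while(i_k_itr > 0)
    {
        // old (buggy): ck::math::min(C - i_c_itr_k, k_per_block)
        const int current_k_block_along_c = std::min(C - i_c_itr_k, i_k_itr);
        std::copy(src, src + current_k_block_along_c, dst);
        src += current_k_block_along_c;
        dst += current_k_block_along_c;
        i_c_itr_k += current_k_block_along_c;
        if(i_c_itr_k >= C)
            i_c_itr_k = 0;       // wrapped around C: move on to the next input pixel
        i_k_itr -= current_k_block_along_c;
    }
}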
@@ -765,11 +766,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
 }
 else
 {
-const ck::index_t n_per_block =
-dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
-dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
-const ck::index_t k_per_block =
-dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+const ck::index_t n_per_block = slice_length[Number<0>{}] * slice_length[Number<2>{}];
+const ck::index_t k_per_block = slice_length[Number<1>{}];
 // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
 // dst_desc.GetTransforms()[Number<0>{}]
@@ -1002,7 +1000,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
 if constexpr(!std::is_same<ElementwiseOperation,
 ck::tensor_operation::cpu::element_wise::PassThrough>::value)
 {
-// if (true) {
 const ck::index_t m_per_block = slice_length[Number<0>{}];
 const ck::index_t n_per_block = slice_length[Number<1>{}];
@@ -1073,11 +1070,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
 }
 else
 {
-const ck::index_t m_per_block =
-src_desc.GetTransforms()[Number<0>{}]
-.GetUpperLengths()[Number<0>{}]; // must be multiple of 8
-const ck::index_t n_per_block =
-src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+const ck::index_t m_per_block = slice_length[Number<0>{}];
+const ck::index_t n_per_block = slice_length[Number<1>{}];
 const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);