"vscode:/vscode.git/clone" did not exist on "8784a72e23538d594ea6b1bd527478fba2962d30"
Commit 05d38218 authored by carlushuang's avatar carlushuang
Browse files

fix a bug in direct conv 4G size

parent e8f639d2
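The "4G size" in the commit title refers to tensors whose linearized element offsets exceed the 32-bit range: `ck::index_t` is a 32-bit signed integer, so offset arithmetic such as `(i_n * Ho + i_ho) * Wo` wraps past 2^31, while `intptr_t` is pointer-sized. Below is a minimal standalone sketch of the failure mode this diff fixes; the sizes are illustrative and not taken from the commit.

// Illustrative sketch (not part of the commit): why 32-bit offsets break
// for tensors with more than 2^31 elements, and pointer-sized ones do not.
#include <cstdint>
#include <cstdio>

using index_t = int32_t; // stand-in for ck::index_t

int main()
{
    // NHWC sizes chosen so the tensor holds 3 * 2^30 elements (> 2^31).
    const intptr_t N = 3, Hi = 1024, Wi = 1024, C = 1024;

    // Offset of the first element of the last image, computed in 64 bits.
    const intptr_t good = (N - 1) * Hi * Wi * C; // 2147483648, correct

    // The same value narrowed to a 32-bit index wraps to a negative number,
    // which would then address far outside the buffer.
    const index_t bad = static_cast<index_t>(good); // -2147483648

    std::printf("intptr_t offset: %lld\n", static_cast<long long>(good));
    std::printf("index_t  offset: %d\n", bad);
    return 0;
}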
@@ -286,8 +286,8 @@ struct GridwiseDirectConvNHWCAvx2
 return is_valid;
 }
-static ck::index_t
-GetBBlockStartOffset(const BGridDesc& b_grid_desc, const index_t i_k, const index_t i_n)
+static intptr_t
+GetBBlockStartOffset(const BGridDesc& b_grid_desc, const intptr_t i_k, const intptr_t i_n)
 {
 if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
 ck::tensor_layout::gemm::RowMajor>::value)
@@ -303,13 +303,13 @@ struct GridwiseDirectConvNHWCAvx2
 }
 }
-static ck::index_t
-GetCBlockStartOffset(const CGridDesc& c_grid_desc, const index_t i_m, const index_t i_n)
+static intptr_t
+GetCBlockStartOffset(const CGridDesc& c_grid_desc, const intptr_t i_m, const intptr_t i_n)
 {
 return i_m * c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
 }
-static ck::index_t GetBLeadingElement(const BGridDesc& b_grid_desc)
+static intptr_t GetBLeadingElement(const BGridDesc& b_grid_desc)
 {
 if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
 ck::tensor_layout::gemm::RowMajor>::value)
@@ -324,7 +324,7 @@ struct GridwiseDirectConvNHWCAvx2
 }
 }
-static ck::index_t GetCLeadingElement(const CGridDesc& c_grid_desc)
+static intptr_t GetCLeadingElement(const CGridDesc& c_grid_desc)
 {
 return c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
 }
@@ -357,25 +357,25 @@ struct GridwiseDirectConvNHWCAvx2
 const auto GemmN = c_grid_desc.GetLength(I1);
 const auto GemmK = a_grid_desc.GetLength(I1);
-const index_t Hi = input_spatial_lengths[0];
-const index_t Wi = input_spatial_lengths[1];
+const intptr_t Hi = input_spatial_lengths[0];
+const intptr_t Wi = input_spatial_lengths[1];
-const index_t Ho = output_spatial_lengths[0];
-const index_t Wo = output_spatial_lengths[1];
+const intptr_t Ho = output_spatial_lengths[0];
+const intptr_t Wo = output_spatial_lengths[1];
-const index_t Y = filter_spatial_lengths[0];
-const index_t X = filter_spatial_lengths[1];
+const intptr_t Y = filter_spatial_lengths[0];
+const intptr_t X = filter_spatial_lengths[1];
-const index_t Sy = conv_filter_strides[0];
-const index_t Sx = conv_filter_strides[1];
+const intptr_t Sy = conv_filter_strides[0];
+const intptr_t Sx = conv_filter_strides[1];
-const index_t Dy = conv_filter_dilations[0];
-const index_t Dx = conv_filter_dilations[1];
+const intptr_t Dy = conv_filter_dilations[0];
+const intptr_t Dx = conv_filter_dilations[1];
-const index_t Py = input_left_pads[0];
-const index_t Px = input_left_pads[1];
+const intptr_t Py = input_left_pads[0];
+const intptr_t Px = input_left_pads[1];
-const index_t X_Dx = X * Dx;
+const intptr_t X_Dx = X * Dx;
 // const index_t Y_Dy = Y * Dy;
 // const index_t InRightPadH = input_right_pads[0];
@@ -421,11 +421,11 @@ struct GridwiseDirectConvNHWCAvx2
 }
 return t_;
 };
-const ck::index_t num_works_n = N;
-const ck::index_t num_works_ho = Ho;
-// const ck::index_t num_works_nho = N * Ho;
-const ck::index_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
-const ck::index_t num_works_k = math::integer_divide_ceil(K, n_per_thread);
+const intptr_t num_works_n = N;
+const intptr_t num_works_ho = Ho;
+// const intptr_t num_works_nho = N * Ho;
+const intptr_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
+const intptr_t num_works_k = math::integer_divide_ceil(K, n_per_thread);
 auto distribute_num_threads_n_ho_wo_k = [&](ck::index_t& num_threads_n_,
 ck::index_t& num_threads_ho_,
@@ -545,40 +545,41 @@ struct GridwiseDirectConvNHWCAvx2
 // }
 // };
-for(ck::index_t i_n = tid_n * num_works_n_per_thread;
+for(intptr_t i_n = tid_n * num_works_n_per_thread;
 (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
 i_n += 1)
 {
-for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread;
+for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
 (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
 i_ho += 1)
 {
 // for input
-ck::index_t i_hi_no_y = i_ho * Sy - Py;
+intptr_t i_hi_no_y = i_ho * Sy - Py;
-for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
+for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
 i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
 i_wo < Wo;
 i_wo += m_per_thread)
 {
-ck::index_t current_wo_size_no_dx =
-ck::math::min(Wo - i_wo, m_per_thread);
-ck::index_t i_wi_no_x = i_wo * Sx - Px;
+intptr_t current_wo_size_no_dx =
+ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
+intptr_t i_wi_no_x = i_wo * Sx - Px;
 // printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
 // num_threads_nho:%d(Hi:%d,nWi:%d)\n",
 // i_nho, i_wo, num_works_nho, num_threads_nho, Hi,
 // Wi);fflush(stdout);
-for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
 i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
 i_k += n_per_thread)
 {
-ck::index_t i_dx = 0;
-ck::index_t i_dy = 0;
+intptr_t i_dx = 0;
+intptr_t i_dy = 0;
 bool accmulate_c = false;
-ck::index_t current_k_size = ck::math::min(K - i_k, n_per_thread);
+intptr_t current_k_size =
+ck::math::min(K - i_k, (intptr_t)n_per_thread);
 auto accumulate_dy_dx = [&]() {
 i_dx += Dx;
@@ -589,25 +590,25 @@ struct GridwiseDirectConvNHWCAvx2
 }
 };
-for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
+for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
 i_yxc += C, accumulate_dy_dx())
 {
-ck::index_t current_i_wo = i_wo;
-ck::index_t i_hi = i_hi_no_y + i_dy;
+intptr_t current_i_wo = i_wo;
+intptr_t i_hi = i_hi_no_y + i_dy;
 if(i_hi < 0 || i_hi >= Hi)
 continue;
-ck::index_t i_wi = i_wi_no_x + i_dx;
-ck::index_t current_wo_size = current_wo_size_no_dx;
-ck::index_t pad_wo_size = 0; // when left pad, we may never have
-// a chance to clear zero (like
+intptr_t i_wi = i_wi_no_x + i_dx;
+intptr_t current_wo_size = current_wo_size_no_dx;
+intptr_t pad_wo_size = 0; // when left pad, we may never have
+// a chance to clear zero (like
 // padding) we need to manually clear that
 if(i_wi < 0)
 {
-ck::index_t wi_to_zero_length =
+intptr_t wi_to_zero_length =
 -i_wi; // keep this a possitive number
-ck::index_t steps_wo_turn_possitive =
+intptr_t steps_wo_turn_possitive =
 (wi_to_zero_length + Sx - 1) /
 Sx; // how many steps need to move wo, to let wi to be
 // possitive
@@ -647,7 +648,7 @@ struct GridwiseDirectConvNHWCAvx2
 if(pad_wo_size != 0)
 {
-for(ck::index_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
+for(intptr_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
 i_wo_pad++)
 {
 const intptr_t offset_c = GetCBlockStartOffset(
@@ -747,28 +748,28 @@ struct GridwiseDirectConvNHWCAvx2
 tid /= num_threads_wo;
 const ck::index_t tid_k = tid;
-for(ck::index_t i_n = tid_n * num_works_n_per_thread;
+for(intptr_t i_n = tid_n * num_works_n_per_thread;
 (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
 i_n += 1)
 {
-for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread;
+for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
 (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
 i_ho += 1)
 {
 // for input
-ck::index_t i_hi_no_y = i_ho * Sy - Py;
+intptr_t i_hi_no_y = i_ho * Sy - Py;
-for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
+for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
 i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
 i_wo < Wo;
 i_wo += m_per_thread)
 {
-ck::index_t current_wo_size_no_dx =
-ck::math::min(Wo - i_wo, m_per_thread);
-ck::index_t i_wi_no_x = i_wo * Sx - Px;
+intptr_t current_wo_size_no_dx =
+ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
+intptr_t i_wi_no_x = i_wo * Sx - Px;
-ck::index_t i_dx = 0;
-ck::index_t i_dy = 0;
+intptr_t i_dx = 0;
+intptr_t i_dy = 0;
 bool accmulate_c = false;
 // printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d,
 // num_threads_n:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
@@ -785,19 +786,19 @@ struct GridwiseDirectConvNHWCAvx2
 }
 };
-for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
+for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
 i_yxc += C, accumulate_dy_dx())
 {
-ck::index_t current_i_wo = i_wo;
-ck::index_t i_hi = i_hi_no_y + i_dy;
-bool run_pad_only = false;
+intptr_t current_i_wo = i_wo;
+intptr_t i_hi = i_hi_no_y + i_dy;
+bool run_pad_only = false;
 if(i_hi < 0 || i_hi >= Hi)
 continue;
-ck::index_t i_wi = i_wi_no_x + i_dx;
-ck::index_t current_wo_size = current_wo_size_no_dx;
-ck::index_t pad_wo_size = 0; // when left pad, we may never have a
-// chance to clear zero (like
+intptr_t i_wi = i_wi_no_x + i_dx;
+intptr_t current_wo_size = current_wo_size_no_dx;
+intptr_t pad_wo_size = 0; // when left pad, we may never have a
+// chance to clear zero (like
 // padding) we need to manually clear that
 /* left corner shift
@@ -812,9 +813,9 @@ struct GridwiseDirectConvNHWCAvx2
 */
 if(i_wi < 0)
 {
-ck::index_t wi_to_zero_length =
+intptr_t wi_to_zero_length =
 -i_wi; // keep this a possitive number
-ck::index_t steps_wo_turn_possitive =
+intptr_t steps_wo_turn_possitive =
 (wi_to_zero_length + Sx - 1) /
 Sx; // how many steps need to move wo, to let wi to be
 // possitive
@@ -859,9 +860,9 @@ struct GridwiseDirectConvNHWCAvx2
 {
 // manually clear zero. this may and only may need once along
 // the gemm_k reduction
-ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
-ck::index_t current_k_block_size = ck::math::min(
-K - i_k, num_works_k_per_thread * n_per_thread);
+intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+intptr_t current_k_block_size = ck::math::min(
+K - i_k, (intptr_t)num_works_k_per_thread * n_per_thread);
 const intptr_t offset_c = GetCBlockStartOffset(
 c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k);
@@ -879,12 +880,12 @@ struct GridwiseDirectConvNHWCAvx2
 if(run_pad_only)
 continue;
-for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
 i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
 i_k += n_per_thread)
 {
-ck::index_t current_k_size =
-ck::math::min(K - i_k, n_per_thread);
+intptr_t current_k_size =
+ck::math::min(K - i_k, (intptr_t)n_per_thread);
 const intptr_t offset_a = current_input_offset;
 const intptr_t offset_b =
@@ -1054,7 +1054,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
 float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
 // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
-for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
+for(intptr_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
 {
 intptr_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), (intptr_t)8);
 intptr_t i_k_itr = k_per_block;
@@ -1150,9 +1150,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
 const float* p_src_k = p_src;
 float* p_dst_k = p_dst;
-for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
+for(intptr_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
 {
-for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
+for(intptr_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
 {
 intptr_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
@@ -1269,7 +1269,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
 float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
 // n0 * k * n1
-index_t i_n0_itr = n0_per_block;
+intptr_t i_n0_itr = n0_per_block;
 while(i_n0_itr >= 8)
 {
 avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
@@ -1440,7 +1440,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
 float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
 // k * n
-index_t i_k_itr = k_per_block;
+intptr_t i_k_itr = k_per_block;
 while(i_k_itr >= 8)
 {
 avx2_util::memcpy32_avx2(
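A recurring detail in these hunks: wherever one operand of ck::math::min became intptr_t, the other operand gained an explicit (intptr_t) cast, as in ck::math::min(Wo - i_wo, (intptr_t)m_per_thread). Assuming ck::math::min deduces a single type from both arguments, as std::min does, mixed intptr_t/ck::index_t arguments would fail to compile; the sketch below uses std::min as a stand-in and illustrative values.

#include <algorithm>
#include <cstdint>

using index_t = int32_t; // stand-in for ck::index_t

int main()
{
    const intptr_t Wo = 4096, i_wo = 4090; // illustrative sizes
    const index_t m_per_thread = 8;

    // std::min(Wo - i_wo, m_per_thread);  // ill-formed: T deduced as both
    //                                     // intptr_t and int32_t
    const intptr_t current_wo_size =
        std::min(Wo - i_wo, static_cast<intptr_t>(m_per_thread)); // 6
    return static_cast<int>(current_wo_size);
}

The left-padding hunks also rely on (wi_to_zero_length + Sx - 1) / Sx being a ceiling division: for example, with i_wi = -3 and stride Sx = 2, ceil(3/2) = 2 output columns fall inside the left pad and must be zero-filled before the input index turns non-negative.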