Commit 6ffd41ae authored by carlushuang

fix a bug when upsampling values

parent d8fef836
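In short, the commit splits the per-thread work decomposition that used to fuse batch and output height into a single N * Ho dimension into separate N and Ho dimensions, and it clears the output buffer up front whenever the dilated filter extent fits entirely inside the padding, i.e. (X - 1) * Dx + 1 <= Px or (Y - 1) * Dy + 1 <= Py, because in that upsampling-style case some output pixels are never touched by the inner GEMM loops and would otherwise keep stale values. A standalone sketch of that guard condition (example sizes, not part of the commit):

// Standalone sketch of the zero-fill guard (example sizes, not part of the commit).
// Names mirror the kernel: Y/X filter size, Dy/Dx dilation, Py/Px padding.
#include <cstdio>

int main()
{
    const int Y = 2, X = 2;   // filter height / width
    const int Dy = 2, Dx = 2; // dilation
    const int Py = 3, Px = 3; // padding

    const int eff_y = (Y - 1) * Dy + 1; // dilated filter extent along H
    const int eff_x = (X - 1) * Dx + 1; // dilated filter extent along W

    // same condition as the added code: if the dilated filter fits inside the
    // padding, some output pixels read only padded input and must be pre-cleared
    const bool needs_zero_fill = (eff_x <= Px) || (eff_y <= Py);
    std::printf("dilated filter %dx%d, pad %d/%d -> clear output first: %s\n",
                eff_y, eff_x, Py, Px, needs_zero_fill ? "yes" : "no");
    return 0;
}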
@@ -421,19 +421,25 @@ struct GridwiseDirectConvNHWCAvx2
             }
             return t_;
         };
-        const ck::index_t num_works_nho = N * Ho;
+        const ck::index_t num_works_n  = N;
+        const ck::index_t num_works_ho = Ho;
+        // const ck::index_t num_works_nho = N * Ho;
         const ck::index_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
         const ck::index_t num_works_k = math::integer_divide_ceil(K, n_per_thread);

-        auto distribute_num_threads_nho_wo_k = [&](ck::index_t& num_threads_nho_,
-                                                   ck::index_t& num_threads_wo_,
-                                                   ck::index_t& num_threads_k_) {
+        auto distribute_num_threads_n_ho_wo_k = [&](ck::index_t& num_threads_n_,
+                                                    ck::index_t& num_threads_ho_,
+                                                    ck::index_t& num_threads_wo_,
+                                                    ck::index_t& num_threads_k_) {
             // TODO: only consider multiply of 2 to divide threads
             ck::index_t num_threads = total_threads;
-            num_threads_nho_ = devide_thread(num_threads, num_works_nho, 2);
-            num_threads      = num_threads / num_threads_nho_;
+            num_threads_n_  = devide_thread(num_threads, num_works_n, 2);
+            num_threads     = num_threads / num_threads_n_;
+            num_threads_ho_ = devide_thread(num_threads, num_works_ho, 2);
+            num_threads     = num_threads / num_threads_ho_;
             num_threads_wo_ = devide_thread(num_threads, num_works_wo, 2);
             num_threads     = num_threads / num_threads_wo_;
@@ -442,14 +448,18 @@ struct GridwiseDirectConvNHWCAvx2
             // num_threads = num_threads / num_threads_k_;
         };

-        ck::index_t num_threads_nho;
+        ck::index_t num_threads_n;
+        ck::index_t num_threads_ho;
         ck::index_t num_threads_wo;
         ck::index_t num_threads_k;

-        distribute_num_threads_nho_wo_k(num_threads_nho, num_threads_wo, num_threads_k);
+        distribute_num_threads_n_ho_wo_k(
+            num_threads_n, num_threads_ho, num_threads_wo, num_threads_k);

-        const ck::index_t num_works_nho_per_thread =
-            math::integer_divide_ceil(num_works_nho, num_threads_nho);
+        const ck::index_t num_works_n_per_thread =
+            math::integer_divide_ceil(num_works_n, num_threads_n);
+        const ck::index_t num_works_ho_per_thread =
+            math::integer_divide_ceil(num_works_ho, num_threads_ho);
         const ck::index_t num_works_wo_per_thread =
             math::integer_divide_ceil(num_works_wo, num_threads_wo);
         const ck::index_t num_works_k_per_thread =
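For reference, the work counts above tile Wo by m_per_thread and K by n_per_thread, and each tiled dimension is then split across its thread count with ceil division, so the last thread may own fewer (or no) tiles. A small standalone illustration with made-up sizes (not from the kernel):

// Standalone illustration (made-up sizes) of the tiling and ceil-division split above.
#include <algorithm>
#include <cstdio>

static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int Wo = 17, m_per_thread = 8; // example sizes
    const int num_threads_wo = 2;

    const int num_works_wo            = integer_divide_ceil(Wo, m_per_thread);             // 3 tiles
    const int num_works_wo_per_thread = integer_divide_ceil(num_works_wo, num_threads_wo); // 2 per thread

    for(int tid_wo = 0; tid_wo < num_threads_wo; ++tid_wo)
    {
        // mirrors the loop bounds used later: [begin, end) tile range owned by this thread
        const int begin = tid_wo * num_works_wo_per_thread;
        const int end   = std::min((tid_wo + 1) * num_works_wo_per_thread, num_works_wo);
        std::printf("tid_wo=%d owns wo tiles [%d, %d) of %d, first i_wo = %d\n",
                    tid_wo, begin, end, num_works_wo, begin * m_per_thread);
    }
    return 0;
}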
@@ -460,6 +470,14 @@ struct GridwiseDirectConvNHWCAvx2
         //     num_threads_nho, num_threads_wo, num_threads_k, num_works_nho_per_thread,
         //     num_works_wo_per_thread, num_works_k_per_thread); fflush(stdout);

+        if((X - 1) * Dx + 1 <= Px || (Y - 1) * Dy + 1 <= Py)
+        {
+            // padding zero case, outpout will have zero due to upsampling
+            // TODO: This is ugly and slow
+            ck::cpu::avx2_util::memset32_avx2(&c_grid_buf.p_data_[0], 0, N * Ho * Wo * K);
+            // printf("___ clear\n");
+        }
+
         if(dynamic_tunable.loop_over_spec ==
            ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
                LoopOver_MNK)
@@ -495,10 +513,14 @@ struct GridwiseDirectConvNHWCAvx2
                 UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                 : c_grid_desc.GetElementSpaceSize());

-            const ck::index_t tid     = omp_get_thread_num();
-            const ck::index_t tid_k   = tid % num_threads_k;
-            const ck::index_t tid_wo  = (tid / num_threads_k) % num_threads_wo;
-            const ck::index_t tid_nho = tid / (num_threads_k * num_threads_wo);
+            ck::index_t tid          = omp_get_thread_num();
+            const ck::index_t tid_n  = tid % num_threads_n;
+            tid /= num_threads_n;
+            const ck::index_t tid_ho = tid % num_threads_ho;
+            tid /= num_threads_ho;
+            const ck::index_t tid_wo = tid % num_threads_wo;
+            tid /= num_threads_wo;
+            const ck::index_t tid_k  = tid;

             ck::cpu::ThreadwiseGemmParam param;
             // param.Kr = k_per_block;
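The new code replaces the old div/mod by num_threads_k and num_threads_wo with a mixed-radix decode of the flat OpenMP thread id into (n, ho, wo, k) coordinates, with n varying fastest. A standalone sketch of that decode with example thread counts (not from the kernel):

// Standalone sketch (example thread counts) of the mixed-radix thread-id decode above.
#include <cstdio>

int main()
{
    const int num_threads_n = 2, num_threads_ho = 2, num_threads_wo = 2, num_threads_k = 1;
    const int total_threads = num_threads_n * num_threads_ho * num_threads_wo * num_threads_k;

    for(int t = 0; t < total_threads; ++t)
    {
        int tid = t;
        const int tid_n  = tid % num_threads_n;
        tid /= num_threads_n;
        const int tid_ho = tid % num_threads_ho;
        tid /= num_threads_ho;
        const int tid_wo = tid % num_threads_wo;
        tid /= num_threads_wo;
        const int tid_k  = tid;
        std::printf("tid=%d -> n=%d ho=%d wo=%d k=%d\n", t, tid_n, tid_ho, tid_wo, tid_k);
    }
    return 0;
}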
@@ -510,150 +532,161 @@ struct GridwiseDirectConvNHWCAvx2
             // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
             // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w

-            ck::index_t i_nho = tid_nho * num_works_nho_per_thread;
-            ck::index_t i_ho  = i_nho % Ho;
-            ck::index_t i_n   = i_nho / Ho;
-
-            auto accumulate_n_ho = [&]() {
-                i_ho++;
-                if(i_ho >= Wo)
-                {
-                    i_ho = 0;
-                    i_n++;
-                }
-            };
-
-            for(; (i_nho < (tid_nho + 1) * num_works_nho_per_thread) && (i_nho < num_works_nho);
-                i_nho += 1, accumulate_n_ho())
-            {
-                // for input
-                ck::index_t i_hi_no_y = i_ho * Sy - Py;
-
-                for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
-                    i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread;
-                    i_wo += m_per_thread)
-                {
-                    ck::index_t current_wo_size_no_dx = ck::math::min(Wo - i_wo, m_per_thread);
-                    ck::index_t i_wi_no_x = i_wo * Sx - Px;
-                    // printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
-                    // num_threads_nho:%d(Hi:%d,nWi:%d)\n",
-                    //     i_nho, i_wo, num_works_nho, num_threads_nho, Hi, Wi);fflush(stdout);
-
-                    for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
-                        i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
-                        i_k += n_per_thread)
-                    {
-                        ck::index_t i_dx = 0;
-                        ck::index_t i_dy = 0;
-                        bool accmulate_c = false;
-
-                        ck::index_t current_k_size = ck::math::min(K - i_k, n_per_thread);
-
-                        auto accumulate_dy_dx = [&]() {
-                            i_dx += Dx;
-                            if(i_dx >= X_Dx)
-                            {
-                                i_dx = 0;
-                                i_dy += Dy;
-                            }
-                        };
-
-                        for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
-                            i_yxc += C, accumulate_dy_dx())
-                        {
-                            ck::index_t current_i_wo = i_wo;
-                            ck::index_t i_hi = i_hi_no_y + i_dy;
-                            if(i_hi < 0 || i_hi >= Hi)
-                                continue;
-
-                            ck::index_t i_wi = i_wi_no_x + i_dx;
-                            ck::index_t current_wo_size = current_wo_size_no_dx;
-                            ck::index_t pad_wo_size = 0; // when left pad, we may never have a
-                                                         // chance to clear zero (like
-                                                         // padding) we need to manually clear that
-
-                            if(i_wi < 0)
-                            {
-                                ck::index_t wi_to_zero_length =
-                                    -i_wi; // keep this a possitive number
-                                ck::index_t steps_wo_turn_possitive =
-                                    (wi_to_zero_length + Sx - 1) /
-                                    Sx; // how many steps need to move wo, to let wi to be
-                                        // possitive
-                                current_wo_size -= steps_wo_turn_possitive;
-                                if(current_wo_size <= 0)
-                                    continue;
-                                current_i_wo += steps_wo_turn_possitive;
-                                if(!accmulate_c)
-                                    pad_wo_size =
-                                        steps_wo_turn_possitive; // if already accumulating, no
-                                                                 // need to manually set
-                                i_wi += steps_wo_turn_possitive *
-                                        Sx; // now i_wi will be a possitive number
-                            }
-
-                            if(i_wi >= Wi)
-                                continue;
-
-                            // shrink right wi/wo
-                            if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
-                            {
-                                // printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
-                                // ((current_wo_size - 1) * Sx), current_wo_size);
-                                current_wo_size =
-                                    (Wi - 1 - i_wi) / Sx + 1; // NOTE: this be careful why here
-                                                              // should be compute like this.
-                                if(current_wo_size <= 0)
-                                    continue;
-                            }
-
-                            param.accmulate_c = accmulate_c ? 1 : 0;
-                            accmulate_c = true;
-
-                            intptr_t current_input_offset =
-                                i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;
-
-                            if(pad_wo_size != 0)
-                            {
-                                for(ck::index_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
-                                    i_wo_pad++)
-                                {
-                                    const intptr_t offset_c = GetCBlockStartOffset(
-                                        c_grid_desc, i_nho * Wo + i_wo_pad, i_k);
-
-                                    // printf("pad_wo_size:%d, current_k_block_size:%d, clear
-                                    // offset_c:%d\n",
-                                    //     pad_wo_size, current_k_size,
-                                    //     offset_c);fflush(stdout);
-                                    ck::cpu::avx2_util::memset32_avx2(
-                                        &c_block_buf.p_data_[offset_c], 0, current_k_size);
-                                }
-                            }
-
-                            const intptr_t offset_a = current_input_offset;
-                            const intptr_t offset_b =
-                                GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
-                            const intptr_t offset_c = GetCBlockStartOffset(
-                                c_grid_desc, i_nho * Wo + current_i_wo, i_k);
-
-                            // printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
-                            // i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
-                            // current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d, ldb:%d,
-                            // ldc:%d, acc:%d\n",
-                            //     offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy,
-                            //     i_k, i_ho, current_i_wo, current_wo_size, current_k_size,
-                            //     i_nho, param.lda / sizeof(FloatA), param.ldb /
-                            //     sizeof(FloatB), param.ldc / sizeof(FloatC),
-                            //     param.accmulate_c); fflush(stdout);
-
-                            param.p_a = &a_block_buf.p_data_[offset_a];
-                            param.p_b = &b_block_buf.p_data_[offset_b];
-                            param.p_c = &c_block_buf.p_data_[offset_c];
-
-                            ThreadwiseGemm_Dispatch::Run(
-                                &param, current_wo_size, current_k_size);
-                        }
-                    }
-                }
-            }
+            // ck::index_t i_nho = tid_nho * num_works_nho_per_thread;
+            // ck::index_t i_ho = i_nho % Ho;
+            // ck::index_t i_n = i_nho / Ho;
+
+            // auto accumulate_n_ho = [&]() {
+            //     i_ho++;
+            //     if(i_ho >= Wo)
+            //     {
+            //         i_ho = 0;
+            //         i_n++;
+            //     }
+            // };
+
+            for(ck::index_t i_n = tid_n * num_works_n_per_thread;
+                (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
+                i_n += 1)
+            {
+                for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread;
+                    (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
+                    i_ho += 1)
+                {
+                    // for input
+                    ck::index_t i_hi_no_y = i_ho * Sy - Py;

+                    for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
+                        i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
+                        i_wo < Wo;
+                        i_wo += m_per_thread)
+                    {
+                        ck::index_t current_wo_size_no_dx =
+                            ck::math::min(Wo - i_wo, m_per_thread);
+                        ck::index_t i_wi_no_x = i_wo * Sx - Px;
+                        // printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
+                        // num_threads_nho:%d(Hi:%d,nWi:%d)\n",
+                        //     i_nho, i_wo, num_works_nho, num_threads_nho, Hi,
+                        //     Wi);fflush(stdout);
+
+                        for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+                            i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
+                            i_k += n_per_thread)
+                        {
+                            ck::index_t i_dx = 0;
+                            ck::index_t i_dy = 0;
+                            bool accmulate_c = false;
+
+                            ck::index_t current_k_size = ck::math::min(K - i_k, n_per_thread);
+
+                            auto accumulate_dy_dx = [&]() {
+                                i_dx += Dx;
+                                if(i_dx >= X_Dx)
+                                {
+                                    i_dx = 0;
+                                    i_dy += Dy;
+                                }
+                            };
+
+                            for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
+                                i_yxc += C, accumulate_dy_dx())
+                            {
+                                ck::index_t current_i_wo = i_wo;
+                                ck::index_t i_hi = i_hi_no_y + i_dy;
+                                if(i_hi < 0 || i_hi >= Hi)
+                                    continue;
+
+                                ck::index_t i_wi = i_wi_no_x + i_dx;
+                                ck::index_t current_wo_size = current_wo_size_no_dx;
+                                ck::index_t pad_wo_size = 0; // when left pad, we may never have
+                                                             // a chance to clear zero (like
+                                                             // padding) we need to manually clear that
+
+                                if(i_wi < 0)
+                                {
+                                    ck::index_t wi_to_zero_length =
+                                        -i_wi; // keep this a possitive number
+                                    ck::index_t steps_wo_turn_possitive =
+                                        (wi_to_zero_length + Sx - 1) /
+                                        Sx; // how many steps need to move wo, to let wi to be
+                                            // possitive
+                                    current_wo_size -= steps_wo_turn_possitive;
+                                    if(current_wo_size <= 0)
+                                        continue;
+                                    current_i_wo += steps_wo_turn_possitive;
+                                    if(!accmulate_c)
+                                        pad_wo_size =
+                                            steps_wo_turn_possitive; // if already accumulating,
+                                                                     // no need to manually set
+                                    i_wi += steps_wo_turn_possitive *
+                                            Sx; // now i_wi will be a possitive number
+                                }
+
+                                if(i_wi >= Wi)
+                                    continue;
+
+                                // shrink right wi/wo
+                                if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
+                                {
+                                    // printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
+                                    // ((current_wo_size - 1) * Sx), current_wo_size);
+                                    current_wo_size = (Wi - 1 - i_wi) / Sx +
+                                                      1; // NOTE: this be careful why here
+                                                         // should be compute like this.
+                                    if(current_wo_size <= 0)
+                                        continue;
+                                }
+
+                                param.accmulate_c = accmulate_c ? 1 : 0;
+                                accmulate_c = true;
+
+                                intptr_t current_input_offset =
+                                    i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;
+
+                                if(pad_wo_size != 0)
+                                {
+                                    for(ck::index_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
+                                        i_wo_pad++)
+                                    {
+                                        const intptr_t offset_c = GetCBlockStartOffset(
+                                            c_grid_desc,
+                                            (i_n * Ho + i_ho) * Wo + i_wo_pad,
+                                            i_k);
+
+                                        // printf("pad_wo_size:%d, current_k_block_size:%d,
+                                        // clear offset_c:%d\n",
+                                        //     pad_wo_size, current_k_size,
+                                        //     offset_c);fflush(stdout);
+                                        ck::cpu::avx2_util::memset32_avx2(
+                                            &c_block_buf.p_data_[offset_c], 0, current_k_size);
+                                    }
+                                }
+
+                                const intptr_t offset_a = current_input_offset;
+                                const intptr_t offset_b =
+                                    GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
+                                const intptr_t offset_c = GetCBlockStartOffset(
+                                    c_grid_desc, (i_n * Ho + i_ho) * Wo + current_i_wo, i_k);
+
+                                // printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
+                                // i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
+                                // current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d,
+                                // ldb:%d, ldc:%d, acc:%d\n",
+                                //     offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx,
+                                //     i_dy, i_k, i_ho, current_i_wo, current_wo_size,
+                                //     current_k_size, i_nho, param.lda / sizeof(FloatA),
+                                //     param.ldb / sizeof(FloatB), param.ldc / sizeof(FloatC),
+                                //     param.accmulate_c); fflush(stdout);
+
+                                param.p_a = &a_block_buf.p_data_[offset_a];
+                                param.p_b = &b_block_buf.p_data_[offset_b];
+                                param.p_c = &c_block_buf.p_data_[offset_c];
+
+                                ThreadwiseGemm_Dispatch::Run(
+                                    &param, current_wo_size, current_k_size);
+                            }
+                        }
+                    }
+                }
+            }
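The window clipping inside the i_yxc loop is the subtle part of this hunk: for a tile of m_per_thread output columns, a negative i_wi shifts the tile start right by ceil(-i_wi / Sx) positions (recorded in pad_wo_size so those outputs can be zero-filled), and the right edge is shrunk to (Wi - 1 - i_wi) / Sx + 1 columns so the last read stays inside [0, Wi). A standalone sketch of that arithmetic with example sizes (not part of the commit):

// Standalone sketch (example sizes) of the wo-window clipping used in the loop above.
#include <algorithm>
#include <cstdio>

int main()
{
    // a consistent 1D geometry: Wi=8, X=3, Dx=1, Px=2, Sx=1  =>  Wo=10
    const int Wi = 8, Wo = 10, X = 3, Dx = 1, Px = 2, Sx = 1;
    const int m_per_thread = 6; // output tile width

    for(int i_wo = 0; i_wo < Wo; i_wo += m_per_thread)
        for(int i_dx = 0; i_dx < X * Dx; i_dx += Dx) // one entry per filter tap
        {
            int current_wo_size = std::min(Wo - i_wo, m_per_thread);
            int i_wi            = i_wo * Sx - Px + i_dx;
            int current_i_wo    = i_wo;
            int pad_wo_size     = 0;

            if(i_wi < 0)
            {
                const int steps = (-i_wi + Sx - 1) / Sx; // ceil(-i_wi / Sx)
                current_wo_size -= steps;
                if(current_wo_size <= 0)
                    continue;
                current_i_wo += steps;
                pad_wo_size = steps; // these wo positions would only see padding
                i_wi += steps * Sx;
            }
            if(i_wi >= Wi)
                continue;
            if(i_wi + (current_wo_size - 1) * Sx >= Wi)
                current_wo_size = (Wi - 1 - i_wi) / Sx + 1; // shrink on the right

            std::printf("tile i_wo=%d tap i_dx=%d: pad=%d, gemm at wo=%d..%d (i_wi=%d)\n",
                        i_wo, i_dx, pad_wo_size, current_i_wo,
                        current_i_wo + current_wo_size - 1, i_wi);
        }
    return 0;
}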
@@ -695,11 +728,6 @@ struct GridwiseDirectConvNHWCAvx2
                 UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                 : c_grid_desc.GetElementSpaceSize());

-            const ck::index_t tid     = omp_get_thread_num();
-            const ck::index_t tid_k   = tid % num_threads_k;
-            const ck::index_t tid_wo  = (tid / num_threads_k) % num_threads_wo;
-            const ck::index_t tid_nho = tid / (num_threads_k * num_threads_wo);
-
             ck::cpu::ThreadwiseGemmParam param;
             // param.Kr = k_per_block;
             param.lda = Sx * C * sizeof(FloatA);
@@ -710,164 +738,176 @@ struct GridwiseDirectConvNHWCAvx2
             // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
             // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w

-            ck::index_t i_nho = tid_nho * num_works_nho_per_thread;
-            ck::index_t i_ho  = i_nho % Ho;
-            ck::index_t i_n   = i_nho / Ho;
-
-            auto accumulate_n_ho = [&]() {
-                i_ho++;
-                if(i_ho >= Wo)
-                {
-                    i_ho = 0;
-                    i_n++;
-                }
-            };
-
-            for(; (i_nho < (tid_nho + 1) * num_works_nho_per_thread) && (i_nho < num_works_nho);
-                i_nho += 1, accumulate_n_ho())
-            {
-                // for input
-                ck::index_t i_hi_no_y = i_ho * Sy - Py;
-
-                for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
-                    i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread;
-                    i_wo += m_per_thread)
-                {
-                    ck::index_t current_wo_size_no_dx = ck::math::min(Wo - i_wo, m_per_thread);
-                    ck::index_t i_wi_no_x = i_wo * Sx - Px;
-                    ck::index_t i_dx = 0;
-                    ck::index_t i_dy = 0;
-                    bool accmulate_c = false;
-                    // printf("-- [%d] i_nho:%d, i_wo:%d, num_works_nho:%d,
-                    // num_threads_nho:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
-                    // m_per_thread:%d\n",
-                    //     tid, i_nho, i_wo, num_works_nho, num_threads_nho, Hi, Wi,
-                    //     current_wo_size_no_dx, m_per_thread);fflush(stdout);
-
-                    auto accumulate_dy_dx = [&]() {
-                        i_dx += Dx;
-                        if(i_dx >= X_Dx)
-                        {
-                            i_dx = 0;
-                            i_dy += Dy;
-                        }
-                    };
-
-                    for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
-                        i_yxc += C, accumulate_dy_dx())
-                    {
-                        ck::index_t current_i_wo = i_wo;
-                        ck::index_t i_hi = i_hi_no_y + i_dy;
-                        if(i_hi < 0 || i_hi >= Hi)
-                            continue;
-
-                        ck::index_t i_wi = i_wi_no_x + i_dx;
-                        ck::index_t current_wo_size = current_wo_size_no_dx;
-                        ck::index_t pad_wo_size =
-                            0; // when left pad, we may never have a chance to clear zero (like
-                               // padding) we need to manually clear that
-
-                        /* left corner shift
-                         * when i_wi is negative, need shift i_wo to right to make i_wi possitive
-                         *   sx   px   i_wi           steps_wo_turn_possitive
-                         *   1    0    0, 1, 2....    0
-                         *   2    0    0, 2, 4...     0
-                         *   1    1    -1, 0, 1....   1
-                         *   2    1    -1, 1, 3....   1
-                         *   2    2    -2, 0, 2...    1
-                         *   2    3    -3, -1, 1...   2
-                         *   3    1    -1, 2, 5...    1
-                         *   3    2    -2, 1, 4....   1
-                         *   3    3    -3, 0, 3       1
-                         *   3    4    -4, -1, 2...   2
-                         */
-                        if(i_wi < 0)
-                        {
-                            ck::index_t wi_to_zero_length =
-                                -i_wi; // keep this a possitive number
-                            ck::index_t steps_wo_turn_possitive =
-                                (wi_to_zero_length + Sx - 1) /
-                                Sx; // how many steps need to move wo, to let wi to be possitive
-                            current_wo_size -= steps_wo_turn_possitive;
-                            if(current_wo_size <= 0)
-                                continue;
-                            current_i_wo += steps_wo_turn_possitive;
-                            if(!accmulate_c)
-                                pad_wo_size =
-                                    steps_wo_turn_possitive; // if already accumulating, no need
-                                                             // to manually set
-                            i_wi += steps_wo_turn_possitive *
-                                    Sx; // now i_wi will be a possitive number
-                        }
-
-                        if(i_wi >= Wi)
-                            continue;
-
-                        // shrink right wi/wo
-                        if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
-                        {
-                            // printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
-                            // ((current_wo_size - 1) * Sx), current_wo_size);
-                            current_wo_size =
-                                (Wi - 1 - i_wi) / Sx + 1; // NOTE: this be careful why here
-                                                          // should be compute like this.
-                            if(current_wo_size <= 0)
-                                continue;
-                        }
-
-                        param.accmulate_c = accmulate_c ? 1 : 0;
-                        accmulate_c = true;
-
-                        intptr_t current_input_offset =
-                            i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;
-
-                        if(pad_wo_size != 0)
-                        {
-                            // manually clear zero. this may and only may need once along the
-                            // gemm_k reduction
-                            ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
-                            ck::index_t current_k_block_size =
-                                ck::math::min(K - i_k, num_works_k_per_thread * n_per_thread);
-                            const intptr_t offset_c =
-                                GetCBlockStartOffset(c_grid_desc, i_nho * Wo, i_k);
-
-                            // printf("[%d] pad_wo_size:%d, current_k_block_size:%d,
-                            // offset_c:%d\n",
-                            //     tid, pad_wo_size, current_k_block_size,
-                            //     offset_c);fflush(stdout);
-                            ck::cpu::avx2_util::memset32_avx2(&c_block_buf.p_data_[offset_c],
-                                                              0,
-                                                              current_k_block_size *
-                                                                  pad_wo_size);
-                        }
-
-                        for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
-                            i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
-                            i_k += n_per_thread)
-                        {
-                            ck::index_t current_k_size = ck::math::min(K - i_k, n_per_thread);
-
-                            const intptr_t offset_a = current_input_offset;
-                            const intptr_t offset_b =
-                                GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
-                            const intptr_t offset_c = GetCBlockStartOffset(
-                                c_grid_desc, i_nho * Wo + current_i_wo, i_k);
-
-                            // printf("[%d] offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
-                            // i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
-                            // current_wo_size:%d, i_nho:%d, lda:%d, ldb:%d\n",
-                            //     tid, offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx,
-                            //     i_dy, i_k, i_ho, current_i_wo, current_wo_size, i_nho,
-                            //     param.lda / sizeof(FloatA), param.ldb / sizeof(FloatB));
-                            //     fflush(stdout);
-
-                            param.p_a = &a_block_buf.p_data_[offset_a];
-                            param.p_b = &b_block_buf.p_data_[offset_b];
-                            param.p_c = &c_block_buf.p_data_[offset_c];
-
-                            ThreadwiseGemm_Dispatch::Run(
-                                &param, current_wo_size, current_k_size);
-                        }
-                    }
-                }
-            }
+            ck::index_t tid          = omp_get_thread_num();
+            const ck::index_t tid_n  = tid % num_threads_n;
+            tid /= num_threads_n;
+            const ck::index_t tid_ho = tid % num_threads_ho;
+            tid /= num_threads_ho;
+            const ck::index_t tid_wo = tid % num_threads_wo;
+            tid /= num_threads_wo;
+            const ck::index_t tid_k  = tid;
+
+            for(ck::index_t i_n = tid_n * num_works_n_per_thread;
+                (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
+                i_n += 1)
+            {
+                for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread;
+                    (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
+                    i_ho += 1)
+                {
+                    // for input
+                    ck::index_t i_hi_no_y = i_ho * Sy - Py;
+
+                    for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
+                        i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
+                        i_wo < Wo;
+                        i_wo += m_per_thread)
+                    {
+                        ck::index_t current_wo_size_no_dx =
+                            ck::math::min(Wo - i_wo, m_per_thread);
+                        ck::index_t i_wi_no_x = i_wo * Sx - Px;
+
+                        ck::index_t i_dx = 0;
+                        ck::index_t i_dy = 0;
+                        bool accmulate_c = false;
+                        // printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d,
+                        // num_threads_n:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
+                        // m_per_thread:%d\n",
+                        //     tid, i_n, i_ho, i_wo, num_works_n, num_threads_n, Hi, Wi,
+                        //     current_wo_size_no_dx, m_per_thread);fflush(stdout);
+
+                        auto accumulate_dy_dx = [&]() {
+                            i_dx += Dx;
+                            if(i_dx >= X_Dx)
+                            {
+                                i_dx = 0;
+                                i_dy += Dy;
+                            }
+                        };
+
+                        for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C);
+                            i_yxc += C, accumulate_dy_dx())
+                        {
+                            ck::index_t current_i_wo = i_wo;
+                            ck::index_t i_hi = i_hi_no_y + i_dy;
+                            bool run_pad_only = false;
+                            if(i_hi < 0 || i_hi >= Hi)
+                                continue;
+
+                            ck::index_t i_wi = i_wi_no_x + i_dx;
+                            ck::index_t current_wo_size = current_wo_size_no_dx;
+                            ck::index_t pad_wo_size = 0; // when left pad, we may never have a
+                                                         // chance to clear zero (like
+                                                         // padding) we need to manually clear that
+
+                            /* left corner shift
+                             * when i_wi is negative, need shift i_wo to right to make i_wi possitive
+                             *   sx   px   i_wi           steps_wo_turn_possitive
+                             *   1    0    0, 1, 2....    0
+                             *   2    0    0, 2, 4...     0
+                             *   1    1    -1, 0, 1....   1
+                             *   2    1    -1, 1, 3....   1
+                             *   2    2    -2, 0, 2...    1
+                             *   2    3    -3, -1, 1...   2
+                             *   3    1    -1, 2, 5...    1
+                             *   3    2    -2, 1, 4....   1
+                             *   3    3    -3, 0, 3       1
+                             *   3    4    -4, -1, 2...   2
+                             */
+                            if(i_wi < 0)
+                            {
+                                ck::index_t wi_to_zero_length =
+                                    -i_wi; // keep this a possitive number
+                                ck::index_t steps_wo_turn_possitive =
+                                    (wi_to_zero_length + Sx - 1) /
+                                    Sx; // how many steps need to move wo, to let wi to be
+                                        // possitive
+                                current_wo_size -= steps_wo_turn_possitive;
+                                // printf("--- current_wo_size:%d, i_wi:%d\n", current_wo_size,
+                                // i_wi);
+                                if(current_wo_size <= 0)
+                                    continue;
+                                current_i_wo += steps_wo_turn_possitive;
+                                if(!accmulate_c)
+                                    pad_wo_size =
+                                        steps_wo_turn_possitive; // if already accumulating, no
+                                                                 // need to manually set
+                                i_wi += steps_wo_turn_possitive *
+                                        Sx; // now i_wi will be a possitive number
+                            }
+
+                            if(i_wi >= Wi)
+                            {
+                                continue;
+                            }
+
+                            // shrink right wi/wo
+                            if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
+                            {
+                                // printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi, i_wi +
+                                // ((current_wo_size - 1) * Sx), current_wo_size);
+                                current_wo_size =
+                                    (Wi - 1 - i_wi) / Sx + 1; // NOTE: this be careful why here
+                                                              // should be compute like this.
+                                if(current_wo_size <= 0)
+                                    continue;
+                            }
+
+                            param.accmulate_c = accmulate_c ? 1 : 0;
+                            accmulate_c = true;
+
+                            intptr_t current_input_offset =
+                                i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;
+
+                            if(pad_wo_size != 0)
+                            {
+                                // manually clear zero. this may and only may need once along
+                                // the gemm_k reduction
+                                ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+                                ck::index_t current_k_block_size = ck::math::min(
+                                    K - i_k, num_works_k_per_thread * n_per_thread);
+                                const intptr_t offset_c = GetCBlockStartOffset(
+                                    c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k);
+
+                                // printf("[%d] pad_wo_size:%d, current_k_block_size:%d,
+                                // offset_c:%d, i_wo:%d\n",
+                                //     tid, pad_wo_size, current_k_block_size, offset_c,
+                                //     i_wo);fflush(stdout);
+                                ck::cpu::avx2_util::memset32_avx2(
+                                    &c_block_buf.p_data_[offset_c],
+                                    0,
+                                    current_k_block_size * pad_wo_size);
+                            }
+
+                            if(run_pad_only)
+                                continue;
+
+                            for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
+                                i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
+                                i_k += n_per_thread)
+                            {
+                                ck::index_t current_k_size =
+                                    ck::math::min(K - i_k, n_per_thread);
+
+                                const intptr_t offset_a = current_input_offset;
+                                const intptr_t offset_b =
+                                    GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
+                                const intptr_t offset_c = GetCBlockStartOffset(
+                                    c_grid_desc, (i_n * Ho + i_ho) * Wo + current_i_wo, i_k);
+
+                                // printf("[%d] offset_a:%lu, offset_b:%lu, offset_c:%lu,
+                                // i_n:%d, i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d,
+                                // i_wo:%d, current_wo_size:%d, i_n:%d, i_ho:%d, lda:%d,
+                                // ldb:%d\n",
+                                //     tid, offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx,
+                                //     i_dy, i_k, i_ho, current_i_wo, current_wo_size, i_n,
+                                //     i_ho, param.lda / sizeof(FloatA), param.ldb /
+                                //     sizeof(FloatB)); fflush(stdout);
+
+                                param.p_a = &a_block_buf.p_data_[offset_a];
+                                param.p_b = &b_block_buf.p_data_[offset_b];
+                                param.p_c = &c_block_buf.p_data_[offset_c];
+
+                                ThreadwiseGemm_Dispatch::Run(
+                                    &param, current_wo_size, current_k_size);
+                            }
+                        }
+                    }
+                }
+            }
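One detail visible in the removed code: the old accumulate_n_ho lambda rolled i_ho over against Wo rather than Ho, so the derived (i_n, i_ho) pair could drift whenever Ho != Wo; the new explicit i_n and i_ho loops avoid that bookkeeping entirely. A small check with made-up sizes (not from the kernel) showing the drift:

// Standalone check (example sizes) of the removed fused-counter bookkeeping.
#include <cstdio>

int main()
{
    const int N = 2, Ho = 3, Wo = 5; // example sizes with Ho != Wo

    int i_ho = 0, i_n = 0; // bookkeeping as in the removed accumulate_n_ho lambda
    for(int i_nho = 0; i_nho < N * Ho; ++i_nho)
    {
        std::printf("i_nho=%d  direct(n=%d,ho=%d)  lambda(n=%d,ho=%d)%s\n",
                    i_nho, i_nho / Ho, i_nho % Ho, i_n, i_ho,
                    (i_n == i_nho / Ho && i_ho == i_nho % Ho) ? "" : "  <-- drift");

        // the removed lambda advanced i_ho but tested the rollover against Wo, not Ho
        i_ho++;
        if(i_ho >= Wo)
        {
            i_ho = 0;
            i_n++;
        }
    }
    return 0;
}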