Commit d8fef836 authored by carlushuang's avatar carlushuang
Browse files

fix bug for ho/wo out-of-bound access

parent 5742d293
...@@ -425,21 +425,21 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -425,21 +425,21 @@ struct GridwiseDirectConvNHWCAvx2
const ck::index_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread); const ck::index_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
const ck::index_t num_works_k = math::integer_divide_ceil(K, n_per_thread); const ck::index_t num_works_k = math::integer_divide_ceil(K, n_per_thread);
auto distribute_num_threads_nho_wo_k = [&](ck::index_t& num_threads_nho, auto distribute_num_threads_nho_wo_k = [&](ck::index_t& num_threads_nho_,
ck::index_t& num_threads_wo, ck::index_t& num_threads_wo_,
ck::index_t& num_threads_k) { ck::index_t& num_threads_k_) {
// TODO: only consider multiply of 2 to divide threads // TODO: only consider multiply of 2 to divide threads
ck::index_t num_threads = total_threads; ck::index_t num_threads = total_threads;
num_threads_nho = devide_thread(num_threads, num_works_nho, 2); num_threads_nho_ = devide_thread(num_threads, num_works_nho, 2);
num_threads = num_threads / num_threads_nho; num_threads = num_threads / num_threads_nho_;
num_threads_wo = devide_thread(num_threads, num_works_wo, 2); num_threads_wo_ = devide_thread(num_threads, num_works_wo, 2);
num_threads = num_threads / num_threads_wo; num_threads = num_threads / num_threads_wo_;
num_threads_k = devide_thread(num_threads, num_works_k, 2); num_threads_k_ = devide_thread(num_threads, num_works_k, 2);
// num_threads = num_threads / num_threads_k; // num_threads = num_threads / num_threads_k_;
}; };
ck::index_t num_threads_nho; ck::index_t num_threads_nho;
...@@ -456,9 +456,9 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -456,9 +456,9 @@ struct GridwiseDirectConvNHWCAvx2
math::integer_divide_ceil(num_works_k, num_threads_k); math::integer_divide_ceil(num_works_k, num_threads_k);
// printf("num_threads_nho:%d, num_threads_wo:%d, num_threads_k:%d | // printf("num_threads_nho:%d, num_threads_wo:%d, num_threads_k:%d |
// num_works_wo_per_thread:%d, num_works_k_per_thread:%d\n", // num_works_nho_per_thread:%d, num_works_wo_per_thread:%d, num_works_k_per_thread:%d\n",
// num_threads_nho, num_threads_wo, num_threads_k, num_works_wo_per_thread, // num_threads_nho, num_threads_wo, num_threads_k, num_works_nho_per_thread,
// num_works_k_per_thread); // num_works_wo_per_thread, num_works_k_per_thread); fflush(stdout);
if(dynamic_tunable.loop_over_spec == if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t:: ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
...@@ -523,7 +523,7 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -523,7 +523,7 @@ struct GridwiseDirectConvNHWCAvx2
} }
}; };
for(; i_nho < (tid_nho + 1) * num_works_nho_per_thread; for(; (i_nho < (tid_nho + 1) * num_works_nho_per_thread) && (i_nho < num_works_nho);
i_nho += 1, accumulate_n_ho()) i_nho += 1, accumulate_n_ho())
{ {
// for input // for input
...@@ -573,15 +573,6 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -573,15 +573,6 @@ struct GridwiseDirectConvNHWCAvx2
// chance to clear zero (like // chance to clear zero (like
// padding) we need to manually clear that // padding) we need to manually clear that
/* left corner shift
* when i_wi is negative, need shift i_wo to right to make i_wi
* possitive sx px i_wi steps_wo_turn_possitive 1 0
* 0, 1, 2.... 0 2 0 0, 2, 4... 0 1 1 -1,
* 0, 1.... 1 2 1 -1, 1, 3.... 1 2 2 -2, 0, 2... 1 2
* 3 -3, -1, 1... 2 3 1 -1, 2, 5... 1 3 2 -2,
* 1, 4.... 1 3 3 -3, 0, 3 1 3 4 -4,
* -1, 2... 2
*/
if(i_wi < 0) if(i_wi < 0)
{ {
ck::index_t wi_to_zero_length = ck::index_t wi_to_zero_length =
...@@ -603,6 +594,9 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -603,6 +594,9 @@ struct GridwiseDirectConvNHWCAvx2
Sx; // now i_wi will be a possitive number Sx; // now i_wi will be a possitive number
} }
if(i_wi >= Wi)
continue;
// shrink right wi/wo // shrink right wi/wo
if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi) if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
{ {
...@@ -623,24 +617,19 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -623,24 +617,19 @@ struct GridwiseDirectConvNHWCAvx2
if(pad_wo_size != 0) if(pad_wo_size != 0)
{ {
// manually clear zero. this may and only may need once along for(ck::index_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
// the gemm_k reduction i_wo_pad++)
// ck::index_t i_k = tid_k * num_works_k_per_thread * {
// n_per_thread; ck::index_t current_k_block_size = const intptr_t offset_c = GetCBlockStartOffset(
// ck::math::min(K - i_k, num_works_k_per_thread * c_grid_desc, i_nho * Wo + i_wo_pad, i_k);
// n_per_thread);
// printf("pad_wo_size:%d, current_k_block_size:%d, clear
const intptr_t offset_c = // offset_c:%d\n",
GetCBlockStartOffset(c_grid_desc, i_nho * Wo, i_k); // pad_wo_size, current_k_size,
// offset_c);fflush(stdout);
// printf("pad_wo_size:%d, current_k_block_size:%d, clear ck::cpu::avx2_util::memset32_avx2(
// offset_c:%d\n", &c_block_buf.p_data_[offset_c], 0, current_k_size);
// pad_wo_size, current_k_size * pad_wo_size, }
// offset_c);fflush(stdout);
ck::cpu::avx2_util::memset32_avx2(
&c_block_buf.p_data_[offset_c],
0,
current_k_size * pad_wo_size);
} }
const intptr_t offset_a = current_input_offset; const intptr_t offset_a = current_input_offset;
...@@ -652,7 +641,7 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -652,7 +641,7 @@ struct GridwiseDirectConvNHWCAvx2
// printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d, // printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
// i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d, // i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
// current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d, ldb:%d, // current_wo_size:%d, current_k_size:%d, i_nho:%d, lda:%d, ldb:%d,
// ldc:%d, acc:%d", // ldc:%d, acc:%d\n",
// offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy, // offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy,
// i_k, i_ho, current_i_wo, current_wo_size, current_k_size, // i_k, i_ho, current_i_wo, current_wo_size, current_k_size,
// i_nho, param.lda / sizeof(FloatA), param.ldb / // i_nho, param.lda / sizeof(FloatA), param.ldb /
...@@ -665,8 +654,6 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -665,8 +654,6 @@ struct GridwiseDirectConvNHWCAvx2
ThreadwiseGemm_Dispatch::Run( ThreadwiseGemm_Dispatch::Run(
&param, current_wo_size, current_k_size); &param, current_wo_size, current_k_size);
// printf(" ------ \n");fflush(stdout);
} }
} }
} }
...@@ -736,11 +723,10 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -736,11 +723,10 @@ struct GridwiseDirectConvNHWCAvx2
} }
}; };
for(; i_nho < (tid_nho + 1) * num_works_nho_per_thread; for(; (i_nho < (tid_nho + 1) * num_works_nho_per_thread) && (i_nho < num_works_nho);
i_nho += 1, accumulate_n_ho()) i_nho += 1, accumulate_n_ho())
{ {
// for input // for input
ck::index_t i_hi_no_y = i_ho * Sy - Py; ck::index_t i_hi_no_y = i_ho * Sy - Py;
for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread; for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
...@@ -753,9 +739,11 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -753,9 +739,11 @@ struct GridwiseDirectConvNHWCAvx2
ck::index_t i_dx = 0; ck::index_t i_dx = 0;
ck::index_t i_dy = 0; ck::index_t i_dy = 0;
bool accmulate_c = false; bool accmulate_c = false;
// printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d, num_threads_nho:%d(Hi:%d, // printf("-- [%d] i_nho:%d, i_wo:%d, num_works_nho:%d,
// Wi:%d)\n", i_nho, i_wo, num_works_nho, num_threads_nho, Hi, // num_threads_nho:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
// Wi);fflush(stdout); // m_per_thread:%d\n",
// tid, i_nho, i_wo, num_works_nho, num_threads_nho, Hi, Wi,
// current_wo_size_no_dx, m_per_thread);fflush(stdout);
auto accumulate_dy_dx = [&]() { auto accumulate_dy_dx = [&]() {
i_dx += Dx; i_dx += Dx;
...@@ -812,6 +800,9 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -812,6 +800,9 @@ struct GridwiseDirectConvNHWCAvx2
Sx; // now i_wi will be a possitive number Sx; // now i_wi will be a possitive number
} }
if(i_wi >= Wi)
continue;
// shrink right wi/wo // shrink right wi/wo
if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi) if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
{ {
...@@ -841,8 +832,10 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -841,8 +832,10 @@ struct GridwiseDirectConvNHWCAvx2
const intptr_t offset_c = const intptr_t offset_c =
GetCBlockStartOffset(c_grid_desc, i_nho * Wo, i_k); GetCBlockStartOffset(c_grid_desc, i_nho * Wo, i_k);
// printf("pad_wo_size:%d, current_k_block_size:%d, offset_c:%d\n", // printf("[%d] pad_wo_size:%d, current_k_block_size:%d,
// pad_wo_size, current_k_block_size, offset_c);fflush(stdout); // offset_c:%d\n",
// tid, pad_wo_size, current_k_block_size,
// offset_c);fflush(stdout);
ck::cpu::avx2_util::memset32_avx2(&c_block_buf.p_data_[offset_c], ck::cpu::avx2_util::memset32_avx2(&c_block_buf.p_data_[offset_c],
0, 0,
current_k_block_size * current_k_block_size *
...@@ -861,12 +854,13 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -861,12 +854,13 @@ struct GridwiseDirectConvNHWCAvx2
const intptr_t offset_c = GetCBlockStartOffset( const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, i_nho * Wo + current_i_wo, i_k); c_grid_desc, i_nho * Wo + current_i_wo, i_k);
// printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d, // printf("[%d] offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d,
// i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d, // i_hi:%d, i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d,
// current_wo_size:%d, i_nho:%d, lda:%d, ldb:%d", // current_wo_size:%d, i_nho:%d, lda:%d, ldb:%d\n",
// offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy, // tid, offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx,
// i_k, i_ho, current_i_wo, current_wo_size, i_nho, param.lda / // i_dy, i_k, i_ho, current_i_wo, current_wo_size, i_nho,
// sizeof(FloatA), param.ldb / sizeof(FloatB)); fflush(stdout); // param.lda / sizeof(FloatA), param.ldb / sizeof(FloatB));
// fflush(stdout);
param.p_a = &a_block_buf.p_data_[offset_a]; param.p_a = &a_block_buf.p_data_[offset_a];
param.p_b = &b_block_buf.p_data_[offset_b]; param.p_b = &b_block_buf.p_data_[offset_b];
...@@ -874,8 +868,6 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -874,8 +868,6 @@ struct GridwiseDirectConvNHWCAvx2
ThreadwiseGemm_Dispatch::Run( ThreadwiseGemm_Dispatch::Run(
&param, current_wo_size, current_k_size); &param, current_wo_size, current_k_size);
// printf(" ------ \n");fflush(stdout);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment