"git@developer.sourcefind.cn:change/sglang.git" did not exist on "a04efc4933b6dff875217f2191fde2a09a20b40f"
Commit 05d38218 authored by carlushuang's avatar carlushuang
Browse files

Fix a bug in direct conv at 4G size: widen index_t offsets and loop indices to intptr_t

parent e8f639d2
...@@ -286,8 +286,8 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -286,8 +286,8 @@ struct GridwiseDirectConvNHWCAvx2
return is_valid; return is_valid;
} }
static ck::index_t static intptr_t
GetBBlockStartOffset(const BGridDesc& b_grid_desc, const index_t i_k, const index_t i_n) GetBBlockStartOffset(const BGridDesc& b_grid_desc, const intptr_t i_k, const intptr_t i_n)
{ {
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value) ck::tensor_layout::gemm::RowMajor>::value)
...@@ -303,13 +303,13 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -303,13 +303,13 @@ struct GridwiseDirectConvNHWCAvx2
} }
} }
static ck::index_t static intptr_t
GetCBlockStartOffset(const CGridDesc& c_grid_desc, const index_t i_m, const index_t i_n) GetCBlockStartOffset(const CGridDesc& c_grid_desc, const intptr_t i_m, const intptr_t i_n)
{ {
return i_m * c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n; return i_m * c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
} }
static ck::index_t GetBLeadingElement(const BGridDesc& b_grid_desc) static intptr_t GetBLeadingElement(const BGridDesc& b_grid_desc)
{ {
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value) ck::tensor_layout::gemm::RowMajor>::value)
...@@ -324,7 +324,7 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -324,7 +324,7 @@ struct GridwiseDirectConvNHWCAvx2
} }
} }
static ck::index_t GetCLeadingElement(const CGridDesc& c_grid_desc) static intptr_t GetCLeadingElement(const CGridDesc& c_grid_desc)
{ {
return c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; return c_grid_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
} }
...@@ -357,25 +357,25 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -357,25 +357,25 @@ struct GridwiseDirectConvNHWCAvx2
const auto GemmN = c_grid_desc.GetLength(I1); const auto GemmN = c_grid_desc.GetLength(I1);
const auto GemmK = a_grid_desc.GetLength(I1); const auto GemmK = a_grid_desc.GetLength(I1);
const index_t Hi = input_spatial_lengths[0]; const intptr_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1]; const intptr_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0]; const intptr_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1]; const intptr_t Wo = output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0]; const intptr_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1]; const intptr_t X = filter_spatial_lengths[1];
const index_t Sy = conv_filter_strides[0]; const intptr_t Sy = conv_filter_strides[0];
const index_t Sx = conv_filter_strides[1]; const intptr_t Sx = conv_filter_strides[1];
const index_t Dy = conv_filter_dilations[0]; const intptr_t Dy = conv_filter_dilations[0];
const index_t Dx = conv_filter_dilations[1]; const intptr_t Dx = conv_filter_dilations[1];
const index_t Py = input_left_pads[0]; const intptr_t Py = input_left_pads[0];
const index_t Px = input_left_pads[1]; const intptr_t Px = input_left_pads[1];
const index_t X_Dx = X * Dx; const intptr_t X_Dx = X * Dx;
// const index_t Y_Dy = Y * Dy; // const index_t Y_Dy = Y * Dy;
// const index_t InRightPadH = input_right_pads[0]; // const index_t InRightPadH = input_right_pads[0];
...@@ -421,11 +421,11 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -421,11 +421,11 @@ struct GridwiseDirectConvNHWCAvx2
} }
return t_; return t_;
}; };
const ck::index_t num_works_n = N; const intptr_t num_works_n = N;
const ck::index_t num_works_ho = Ho; const intptr_t num_works_ho = Ho;
// const ck::index_t num_works_nho = N * Ho; // const intptr_t num_works_nho = N * Ho;
const ck::index_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread); const intptr_t num_works_wo = math::integer_divide_ceil(Wo, m_per_thread);
const ck::index_t num_works_k = math::integer_divide_ceil(K, n_per_thread); const intptr_t num_works_k = math::integer_divide_ceil(K, n_per_thread);
auto distribute_num_threads_n_ho_wo_k = [&](ck::index_t& num_threads_n_, auto distribute_num_threads_n_ho_wo_k = [&](ck::index_t& num_threads_n_,
ck::index_t& num_threads_ho_, ck::index_t& num_threads_ho_,
...@@ -545,40 +545,41 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -545,40 +545,41 @@ struct GridwiseDirectConvNHWCAvx2
// } // }
// }; // };
for(ck::index_t i_n = tid_n * num_works_n_per_thread; for(intptr_t i_n = tid_n * num_works_n_per_thread;
(i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n; (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
i_n += 1) i_n += 1)
{ {
for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread; for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
(i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho; (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
i_ho += 1) i_ho += 1)
{ {
// for input // for input
ck::index_t i_hi_no_y = i_ho * Sy - Py; intptr_t i_hi_no_y = i_ho * Sy - Py;
for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread; for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread && i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
i_wo < Wo; i_wo < Wo;
i_wo += m_per_thread) i_wo += m_per_thread)
{ {
ck::index_t current_wo_size_no_dx = intptr_t current_wo_size_no_dx =
ck::math::min(Wo - i_wo, m_per_thread); ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
ck::index_t i_wi_no_x = i_wo * Sx - Px; intptr_t i_wi_no_x = i_wo * Sx - Px;
// printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d, // printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d,
// num_threads_nho:%d(Hi:%d,nWi:%d)\n", // num_threads_nho:%d(Hi:%d,nWi:%d)\n",
// i_nho, i_wo, num_works_nho, num_threads_nho, Hi, // i_nho, i_wo, num_works_nho, num_threads_nho, Hi,
// Wi);fflush(stdout); // Wi);fflush(stdout);
for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread; for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread; i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
i_k += n_per_thread) i_k += n_per_thread)
{ {
ck::index_t i_dx = 0; intptr_t i_dx = 0;
ck::index_t i_dy = 0; intptr_t i_dy = 0;
bool accmulate_c = false; bool accmulate_c = false;
ck::index_t current_k_size = ck::math::min(K - i_k, n_per_thread); intptr_t current_k_size =
ck::math::min(K - i_k, (intptr_t)n_per_thread);
auto accumulate_dy_dx = [&]() { auto accumulate_dy_dx = [&]() {
i_dx += Dx; i_dx += Dx;
...@@ -589,25 +590,25 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -589,25 +590,25 @@ struct GridwiseDirectConvNHWCAvx2
} }
}; };
for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C); for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
i_yxc += C, accumulate_dy_dx()) i_yxc += C, accumulate_dy_dx())
{ {
ck::index_t current_i_wo = i_wo; intptr_t current_i_wo = i_wo;
ck::index_t i_hi = i_hi_no_y + i_dy; intptr_t i_hi = i_hi_no_y + i_dy;
if(i_hi < 0 || i_hi >= Hi) if(i_hi < 0 || i_hi >= Hi)
continue; continue;
ck::index_t i_wi = i_wi_no_x + i_dx; intptr_t i_wi = i_wi_no_x + i_dx;
ck::index_t current_wo_size = current_wo_size_no_dx; intptr_t current_wo_size = current_wo_size_no_dx;
ck::index_t pad_wo_size = 0; // when left pad, we may never have intptr_t pad_wo_size = 0; // when left pad, we may never have
// a chance to clear zero (like // a chance to clear zero (like
// padding) we need to manually clear that // padding) we need to manually clear that
if(i_wi < 0) if(i_wi < 0)
{ {
ck::index_t wi_to_zero_length = intptr_t wi_to_zero_length =
-i_wi; // keep this a possitive number -i_wi; // keep this a possitive number
ck::index_t steps_wo_turn_possitive = intptr_t steps_wo_turn_possitive =
(wi_to_zero_length + Sx - 1) / (wi_to_zero_length + Sx - 1) /
Sx; // how many steps need to move wo, to let wi to be Sx; // how many steps need to move wo, to let wi to be
// possitive // possitive
...@@ -647,7 +648,7 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -647,7 +648,7 @@ struct GridwiseDirectConvNHWCAvx2
if(pad_wo_size != 0) if(pad_wo_size != 0)
{ {
for(ck::index_t i_wo_pad = 0; i_wo_pad < pad_wo_size; for(intptr_t i_wo_pad = 0; i_wo_pad < pad_wo_size;
i_wo_pad++) i_wo_pad++)
{ {
const intptr_t offset_c = GetCBlockStartOffset( const intptr_t offset_c = GetCBlockStartOffset(
...@@ -747,28 +748,28 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -747,28 +748,28 @@ struct GridwiseDirectConvNHWCAvx2
tid /= num_threads_wo; tid /= num_threads_wo;
const ck::index_t tid_k = tid; const ck::index_t tid_k = tid;
for(ck::index_t i_n = tid_n * num_works_n_per_thread; for(intptr_t i_n = tid_n * num_works_n_per_thread;
(i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n; (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
i_n += 1) i_n += 1)
{ {
for(ck::index_t i_ho = tid_ho * num_works_ho_per_thread; for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
(i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho; (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
i_ho += 1) i_ho += 1)
{ {
// for input // for input
ck::index_t i_hi_no_y = i_ho * Sy - Py; intptr_t i_hi_no_y = i_ho * Sy - Py;
for(ck::index_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread; for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread && i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread &&
i_wo < Wo; i_wo < Wo;
i_wo += m_per_thread) i_wo += m_per_thread)
{ {
ck::index_t current_wo_size_no_dx = intptr_t current_wo_size_no_dx =
ck::math::min(Wo - i_wo, m_per_thread); ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
ck::index_t i_wi_no_x = i_wo * Sx - Px; intptr_t i_wi_no_x = i_wo * Sx - Px;
ck::index_t i_dx = 0; intptr_t i_dx = 0;
ck::index_t i_dy = 0; intptr_t i_dy = 0;
bool accmulate_c = false; bool accmulate_c = false;
// printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d, // printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d,
// num_threads_n:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d, // num_threads_n:%d(Hi:%d, Wi:%d), current_wo_size_no_dx:%d,
...@@ -785,19 +786,19 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -785,19 +786,19 @@ struct GridwiseDirectConvNHWCAvx2
} }
}; };
for(ck::index_t i_yxc = 0; i_yxc < (Y * X * C); for(intptr_t i_yxc = 0; i_yxc < (Y * X * C);
i_yxc += C, accumulate_dy_dx()) i_yxc += C, accumulate_dy_dx())
{ {
ck::index_t current_i_wo = i_wo; intptr_t current_i_wo = i_wo;
ck::index_t i_hi = i_hi_no_y + i_dy; intptr_t i_hi = i_hi_no_y + i_dy;
bool run_pad_only = false; bool run_pad_only = false;
if(i_hi < 0 || i_hi >= Hi) if(i_hi < 0 || i_hi >= Hi)
continue; continue;
ck::index_t i_wi = i_wi_no_x + i_dx; intptr_t i_wi = i_wi_no_x + i_dx;
ck::index_t current_wo_size = current_wo_size_no_dx; intptr_t current_wo_size = current_wo_size_no_dx;
ck::index_t pad_wo_size = 0; // when left pad, we may never have a intptr_t pad_wo_size = 0; // when left pad, we may never have a
// chance to clear zero (like // chance to clear zero (like
// padding) we need to manually clear that // padding) we need to manually clear that
/* left corner shift /* left corner shift
...@@ -812,9 +813,9 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -812,9 +813,9 @@ struct GridwiseDirectConvNHWCAvx2
*/ */
if(i_wi < 0) if(i_wi < 0)
{ {
ck::index_t wi_to_zero_length = intptr_t wi_to_zero_length =
-i_wi; // keep this a possitive number -i_wi; // keep this a possitive number
ck::index_t steps_wo_turn_possitive = intptr_t steps_wo_turn_possitive =
(wi_to_zero_length + Sx - 1) / (wi_to_zero_length + Sx - 1) /
Sx; // how many steps need to move wo, to let wi to be Sx; // how many steps need to move wo, to let wi to be
// possitive // possitive
...@@ -859,9 +860,9 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -859,9 +860,9 @@ struct GridwiseDirectConvNHWCAvx2
{ {
// manually clear zero. this may and only may need once along // manually clear zero. this may and only may need once along
// the gemm_k reduction // the gemm_k reduction
ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread; intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
ck::index_t current_k_block_size = ck::math::min( intptr_t current_k_block_size = ck::math::min(
K - i_k, num_works_k_per_thread * n_per_thread); K - i_k, (intptr_t)num_works_k_per_thread * n_per_thread);
const intptr_t offset_c = GetCBlockStartOffset( const intptr_t offset_c = GetCBlockStartOffset(
c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k); c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k);
...@@ -879,12 +880,12 @@ struct GridwiseDirectConvNHWCAvx2 ...@@ -879,12 +880,12 @@ struct GridwiseDirectConvNHWCAvx2
if(run_pad_only) if(run_pad_only)
continue; continue;
for(ck::index_t i_k = tid_k * num_works_k_per_thread * n_per_thread; for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread; i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
i_k += n_per_thread) i_k += n_per_thread)
{ {
ck::index_t current_k_size = intptr_t current_k_size =
ck::math::min(K - i_k, n_per_thread); ck::math::min(K - i_k, (intptr_t)n_per_thread);
const intptr_t offset_a = current_input_offset; const intptr_t offset_a = current_input_offset;
const intptr_t offset_b = const intptr_t offset_b =
......
...@@ -1054,7 +1054,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC ...@@ -1054,7 +1054,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_); float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// n * k -> n0 * k * n1, n1 = 8, n0 = n/8 // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8) for(intptr_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
{ {
intptr_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), (intptr_t)8); intptr_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), (intptr_t)8);
intptr_t i_k_itr = k_per_block; intptr_t i_k_itr = k_per_block;
...@@ -1150,9 +1150,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC ...@@ -1150,9 +1150,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
const float* p_src_k = p_src; const float* p_src_k = p_src;
float* p_dst_k = p_dst; float* p_dst_k = p_dst;
for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++) for(intptr_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
{ {
for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++) for(intptr_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
{ {
intptr_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n; intptr_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
...@@ -1269,7 +1269,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8 ...@@ -1269,7 +1269,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_); float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// n0 * k * n1 // n0 * k * n1
index_t i_n0_itr = n0_per_block; intptr_t i_n0_itr = n0_per_block;
while(i_n0_itr >= 8) while(i_n0_itr >= 8)
{ {
avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block, avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
...@@ -1440,7 +1440,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK ...@@ -1440,7 +1440,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_); float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// k * n // k * n
index_t i_k_itr = k_per_block; intptr_t i_k_itr = k_per_block;
while(i_k_itr >= 8) while(i_k_itr >= 8)
{ {
avx2_util::memcpy32_avx2( avx2_util::memcpy32_avx2(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment