Commit 594a4c22 authored by myamlak
Browse files

Review remarks + issues with (un)signed arithmetic

parent 57f2d3c3
...@@ -864,9 +864,9 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -864,9 +864,9 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// Input tensors can't be bigger than 2GB each. // Input tensors can't be bigger than 2GB each.
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 || if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2 ||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 || arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2 ||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2) arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
{ {
return false; return false;
} }
......
...@@ -70,26 +70,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -70,26 +70,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) { auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(int n = 0; n < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]); for(std::size_t n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
++n)
{ {
for(int ho = 0; for(std::size_t ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
ho < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]);
++ho)
{ {
int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] - auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
arg.in_left_pads_[I0]; ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
for(int wo = 0; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
wo < ck::type_convert<int>(arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]); for(std::size_t wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
++wo)
{ {
int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] - auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
arg.in_left_pads_[I1]; ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 && if(hi >= 0 &&
hi < ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])) ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
......
...@@ -72,21 +72,24 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -72,21 +72,24 @@ struct ReferenceConvBwdData : public device::BaseOperator
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
auto f_ncw = [&](auto n, auto c, auto wi) { auto f_ncw = [&](auto n, auto c, auto wi) {
int K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.mDesc.GetLengths()[0];
int X = arg.weight_.mDesc.GetLengths()[2]; std::size_t X = arg.weight_.mDesc.GetLengths()[2];
int Wo = arg.output_.mDesc.GetLengths()[2]; std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
AccDataType v_acc = 0; AccDataType v_acc = 0;
for(int x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0]; auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0) if(w_tmp % arg.conv_strides_[0] == 0)
{ {
int wo = w_tmp / arg.conv_strides_[0]; auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
if(wo >= 0 && wo < Wo) ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(int k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; AccDataType v_out = 0;
AccDataType v_wei = 0; AccDataType v_wei = 0;
...@@ -119,33 +122,38 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -119,33 +122,38 @@ struct ReferenceConvBwdData : public device::BaseOperator
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
int K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.mDesc.GetLengths()[0];
int Y = arg.weight_.mDesc.GetLengths()[2]; std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
int X = arg.weight_.mDesc.GetLengths()[3]; std::size_t X = arg.weight_.mDesc.GetLengths()[3];
int Ho = arg.output_.mDesc.GetLengths()[2]; std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
int Wo = arg.output_.mDesc.GetLengths()[3]; std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
AccDataType v_acc = 0; AccDataType v_acc = 0;
for(int y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0]; auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0) if(h_tmp % arg.conv_strides_[0] == 0)
{ {
int ho = h_tmp / arg.conv_strides_[0]; auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
if(ho >= 0 && ho < Ho) ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(int x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
int w_tmp = auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1]; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0) if(w_tmp % arg.conv_strides_[1] == 0)
{ {
int wo = w_tmp / arg.conv_strides_[1]; auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
if(wo >= 0 && wo < Wo) ck::type_convert<ck::long_index_t>(arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(int k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; AccDataType v_out = 0;
AccDataType v_wei = 0; AccDataType v_wei = 0;
...@@ -183,44 +191,51 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -183,44 +191,51 @@ struct ReferenceConvBwdData : public device::BaseOperator
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
int K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.mDesc.GetLengths()[0];
int Z = arg.weight_.mDesc.GetLengths()[2]; std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
int Y = arg.weight_.mDesc.GetLengths()[3]; std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
int X = arg.weight_.mDesc.GetLengths()[4]; std::size_t X = arg.weight_.mDesc.GetLengths()[4];
int Do = arg.output_.mDesc.GetLengths()[2]; std::size_t Do = arg.output_.mDesc.GetLengths()[2];
int Ho = arg.output_.mDesc.GetLengths()[3]; std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
int Wo = arg.output_.mDesc.GetLengths()[4]; std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
AccDataType v_acc = 0; AccDataType v_acc = 0;
for(int z = 0; z < Z; ++z) for(std::size_t z = 0; z < Z; ++z)
{ {
int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0]; auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0) if(d_tmp % arg.conv_strides_[0] == 0)
{ {
int do_ = d_tmp / arg.conv_strides_[0]; auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
if(do_ >= 0 && do_ < Do) ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{ {
for(int y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
int h_tmp = auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1]; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0) if(h_tmp % arg.conv_strides_[1] == 0)
{ {
int ho = h_tmp / arg.conv_strides_[1]; auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
if(ho >= 0 && ho < Ho) ck::type_convert<ck::long_index_t>(arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(int x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
int w_tmp = wi + arg.in_left_pads_[2] - auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
x * arg.conv_dilations_[2]; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0) if(w_tmp % arg.conv_strides_[2] == 0)
{ {
int wo = w_tmp / arg.conv_strides_[2]; auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
if(wo >= 0 && wo < Wo) ck::type_convert<ck::long_index_t>(arg.conv_strides_[2]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(int k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
AccDataType v_out = 0; AccDataType v_out = 0;
AccDataType v_wei = 0; AccDataType v_wei = 0;
......
...@@ -88,16 +88,15 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -88,16 +88,15 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_ncw = [&](auto n, auto k, auto wo) { auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
++c)
{ {
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
++x)
{ {
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
arg.in_left_pads_[0]; ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2])) ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -131,24 +130,22 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -131,24 +130,22 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
++c)
{ {
for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
arg.in_left_pads_[0]; ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
for(int x = 0; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
arg.in_left_pads_[1]; ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) && ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3])) ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -183,34 +180,29 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -183,34 +180,29 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
++c)
{ {
for(int z = 0; z < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
++z)
{ {
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] - auto di = ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
arg.in_left_pads_[0]; ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
for(int y = 0; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
++y)
{ {
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] - auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
arg.in_left_pads_[1]; ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
for(int x = 0; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[4]); for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
++x)
{ {
int wi = wo * arg.conv_strides_[2] + auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[2]) +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2]; ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
di < ck::type_convert<int>( ck::type_convert<std::size_t>(di) < arg.input_.mDesc.GetLengths()[2] &&
arg.input_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi >= 0 &&
hi < ck::type_convert<int>( ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[3] &&
arg.input_.mDesc.GetLengths()[3]) &&
wi >= 0 && wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[4])) ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[4])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -227,7 +219,7 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -227,7 +219,7 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
} }
} }
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
......
...@@ -73,26 +73,22 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator ...@@ -73,26 +73,22 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[1]); for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
++c)
{ {
for(int y = 0; for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
y < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[2]);
++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
arg.in_left_pads_[0]; ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
for(int x = 0; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
x < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[3]); for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
arg.in_left_pads_[1]; ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
hi < ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])) ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
......
...@@ -76,26 +76,22 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator ...@@ -76,26 +76,22 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[1]); for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
++c)
{ {
for(int y = 0; for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
y < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[2]);
++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
arg.in_left_pads_[0]; ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
for(int x = 0; ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
x < ck::type_convert<int>(arg.wei_k_c_y_x_.mDesc.GetLengths()[3]); for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
arg.in_left_pads_[1]; ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
hi < ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
wi < ck::type_convert<int>(arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])) ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
......
...@@ -83,7 +83,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -83,7 +83,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::vector<Tensor<BDataType>> b_k_n; std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_device_results; std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < Ms.size(); i++) for(std::size_t i = 0; i < group_count; i++)
{ {
a_m_k.push_back( a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment