Commit 31720510 authored by myamlak's avatar myamlak
Browse files

Format fix

parent 594a4c22
...@@ -79,22 +79,25 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -79,22 +79,25 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo) for(std::size_t wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
{ {
auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) + auto wi =
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) - ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]); ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) ck::type_convert<std::size_t>(wi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo))); ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi))); v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
......
...@@ -144,13 +144,16 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -144,13 +144,16 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) + auto w_tmp =
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0) if(w_tmp % arg.conv_strides_[1] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[1]); ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
...@@ -215,25 +218,34 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -215,25 +218,34 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{ {
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) + auto h_tmp =
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) - ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0) if(h_tmp % arg.conv_strides_[1] == 0)
{ {
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) / auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[1]); ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
{ {
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) + auto w_tmp =
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]) - ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]); ck::type_convert<ck::long_index_t>(
arg.in_left_pads_[2]) -
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0) if(w_tmp % arg.conv_strides_[2] == 0)
{ {
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) / auto wo =
ck::type_convert<ck::long_index_t>(arg.conv_strides_[2]); ck::type_convert<ck::long_index_t>(w_tmp) /
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) ck::type_convert<ck::long_index_t>(
arg.conv_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{ {
for(std::size_t k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
......
...@@ -92,9 +92,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -92,9 +92,10 @@ struct ReferenceConvFwd : public device::BaseOperator
{ {
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
{ {
auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi =
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
{ {
...@@ -134,18 +135,22 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -134,18 +135,22 @@ struct ReferenceConvFwd : public device::BaseOperator
{ {
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
{ {
auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi =
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
{ {
auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + auto wi =
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[3]) ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -184,25 +189,33 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -184,25 +189,33 @@ struct ReferenceConvFwd : public device::BaseOperator
{ {
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
{ {
auto di = ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) + auto di =
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y) for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
{ {
auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) + auto hi =
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
{ {
auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[2]) + auto wi =
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]) - ck::type_convert<ck::long_index_t>(wo *
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]); arg.conv_strides_[2]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < arg.input_.mDesc.GetLengths()[2] && ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[3] && ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[4]) ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -219,7 +232,7 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -219,7 +232,7 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
} }
} }
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
......
...@@ -82,13 +82,16 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator ...@@ -82,13 +82,16 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
{ {
auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + auto wi =
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) ck::type_convert<std::size_t>(wi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
......
...@@ -85,13 +85,16 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator ...@@ -85,13 +85,16 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
{ {
auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + auto wi =
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) ck::type_convert<std::size_t>(wi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment