Commit 97f133d2 authored by myamlak's avatar myamlak
Browse files

Explicit static_cast to ck::type_convert

parent b411ee3b
...@@ -88,13 +88,13 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -88,13 +88,13 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_ncw = [&](auto n, auto k, auto wo) { auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{ {
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++x) for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++x)
{ {
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
if(wi >= 0 && wi < static_cast<int>(arg.input_.mDesc.GetLengths()[2])) if(wi >= 0 && wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]))
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -128,18 +128,18 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -128,18 +128,18 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{ {
for(int y = 0; y < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++y) for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++y)
{ {
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[3]); ++x) for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); ++x)
{ {
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1]; arg.in_left_pads_[1];
if(hi >= 0 && hi < static_cast<int>(arg.input_.mDesc.GetLengths()[2]) && wi >= 0 && if(hi >= 0 && hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) && wi >= 0 &&
wi < static_cast<int>(arg.input_.mDesc.GetLengths()[3])) wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3]))
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -174,23 +174,23 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -174,23 +174,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c) for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{ {
for(int z = 0; z < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++z) for(int z = 0; z < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++z)
{ {
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] - int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
arg.in_left_pads_[0]; arg.in_left_pads_[0];
for(int y = 0; y < static_cast<int>(arg.weight_.mDesc.GetLengths()[3]); ++y) for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); ++y)
{ {
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] - int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
arg.in_left_pads_[1]; arg.in_left_pads_[1];
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[4]); ++x) for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[4]); ++x)
{ {
int wi = wo * arg.conv_strides_[2] + int wi = wo * arg.conv_strides_[2] +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2]; x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
if(di >= 0 && di < static_cast<int>(arg.input_.mDesc.GetLengths()[2]) && if(di >= 0 && di < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi < static_cast<int>(arg.input_.mDesc.GetLengths()[3]) && hi >= 0 && hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3]) &&
wi >= 0 && wi < static_cast<int>(arg.input_.mDesc.GetLengths()[4])) wi >= 0 && wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[4]))
{ {
float v_in; float v_in;
float v_wei; float v_wei;
......
...@@ -145,12 +145,12 @@ struct ConvParams ...@@ -145,12 +145,12 @@ struct ConvParams
input_left_pads(left_pads), input_left_pads(left_pads),
input_right_pads(right_pads) input_right_pads(right_pads)
{ {
if(static_cast<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial || if(ck::type_convert<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_left_pads.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_right_pads.size()) != num_dim_spatial) ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{ {
throw(std::runtime_error( throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: " "ConvParams::GetOutputSpatialLengths: "
...@@ -174,12 +174,12 @@ struct ConvParams ...@@ -174,12 +174,12 @@ struct ConvParams
std::vector<ck::index_t> GetOutputSpatialLengths() const std::vector<ck::index_t> GetOutputSpatialLengths() const
{ {
if(static_cast<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial || if(ck::type_convert<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_left_pads.size()) != num_dim_spatial || ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_right_pads.size()) != num_dim_spatial) ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{ {
throw(std::runtime_error( throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: " "ConvParams::GetOutputSpatialLengths: "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment