"test/vscode:/vscode.git/clone" did not exist on "0e92deb717cb67149a603bf5f15e1c304432dbf7"
Commit 97f133d2 authored by myamlak's avatar myamlak
Browse files

Explicit static_cast to ck::type_convert

parent b411ee3b
......@@ -88,13 +88,13 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0;
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++x)
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++x)
{
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
if(wi >= 0 && wi < static_cast<int>(arg.input_.mDesc.GetLengths()[2]))
if(wi >= 0 && wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]))
{
float v_in;
float v_wei;
......@@ -128,18 +128,18 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{
for(int y = 0; y < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++y)
for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++y)
{
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[3]); ++x)
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); ++x)
{
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
if(hi >= 0 && hi < static_cast<int>(arg.input_.mDesc.GetLengths()[2]) && wi >= 0 &&
wi < static_cast<int>(arg.input_.mDesc.GetLengths()[3]))
if(hi >= 0 && hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) && wi >= 0 &&
wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3]))
{
float v_in;
float v_wei;
......@@ -174,23 +174,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(int c = 0; c < static_cast<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
for(int c = 0; c < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[1]); ++c)
{
for(int z = 0; z < static_cast<int>(arg.weight_.mDesc.GetLengths()[2]); ++z)
for(int z = 0; z < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[2]); ++z)
{
int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
arg.in_left_pads_[0];
for(int y = 0; y < static_cast<int>(arg.weight_.mDesc.GetLengths()[3]); ++y)
for(int y = 0; y < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[3]); ++y)
{
int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
arg.in_left_pads_[1];
for(int x = 0; x < static_cast<int>(arg.weight_.mDesc.GetLengths()[4]); ++x)
for(int x = 0; x < ck::type_convert<int>(arg.weight_.mDesc.GetLengths()[4]); ++x)
{
int wi = wo * arg.conv_strides_[2] +
x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
if(di >= 0 && di < static_cast<int>(arg.input_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi < static_cast<int>(arg.input_.mDesc.GetLengths()[3]) &&
wi >= 0 && wi < static_cast<int>(arg.input_.mDesc.GetLengths()[4]))
if(di >= 0 && di < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[3]) &&
wi >= 0 && wi < ck::type_convert<int>(arg.input_.mDesc.GetLengths()[4]))
{
float v_in;
float v_wei;
......
......@@ -145,12 +145,12 @@ struct ConvParams
input_left_pads(left_pads),
input_right_pads(right_pads)
{
if(static_cast<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
if(ck::type_convert<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{
throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: "
......@@ -174,12 +174,12 @@ struct ConvParams
std::vector<ck::index_t> GetOutputSpatialLengths() const
{
if(static_cast<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
static_cast<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
static_cast<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
if(ck::type_convert<ck::index_t>(filter_spatial_lengths.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_spatial_lengths.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(conv_filter_strides.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(conv_filter_dilations.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_left_pads.size()) != num_dim_spatial ||
ck::type_convert<ck::index_t>(input_right_pads.size()) != num_dim_spatial)
{
throw(std::runtime_error(
"ConvParams::GetOutputSpatialLengths: "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment