Commit edc23800 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the data type for lens and strides from size_t to int in the shape class

parent c7419a9c
...@@ -23,7 +23,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -23,7 +23,7 @@ struct parse_pooling : op_parser<parse_pooling>
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<int> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride); parser.reorder_data(stride);
if(stride.size() != 4) if(stride.size() != 4)
...@@ -35,7 +35,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -35,7 +35,7 @@ struct parse_pooling : op_parser<parse_pooling>
} }
if(contains(info.attributes, "ksize")) if(contains(info.attributes, "ksize"))
{ {
std::vector<size_t> ksize; std::vector<int> ksize;
copy(info.attributes.at("ksize").list().i(), std::back_inserter(ksize)); copy(info.attributes.at("ksize").list().i(), std::back_inserter(ksize));
parser.reorder_data(ksize); parser.reorder_data(ksize);
if(ksize.size() != 4) if(ksize.size() != 4)
...@@ -57,7 +57,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -57,7 +57,7 @@ struct parse_pooling : op_parser<parse_pooling>
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]); calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]); calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
op.padding = std::vector<size_t>(pads.begin(), pads.end()); op.padding = std::vector<int>(pads.begin(), pads.end());
} }
} }
return info.add_instruction(op, l0); return info.add_instruction(op, l0);
......
...@@ -19,9 +19,9 @@ struct parse_shape : op_parser<parse_shape> ...@@ -19,9 +19,9 @@ struct parse_shape : op_parser<parse_shape>
const tf_parser::node_info& info, const tf_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens(); std::vector<int> arg_shape = args[0]->get_shape().lens();
std::vector<int32_t> vec_shape(arg_shape.size()); std::vector<int32_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()}); migraphx::shape s(migraphx::shape::int32_type, {static_cast<int>(arg_shape.size())});
std::transform( std::transform(
arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; }); arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
return info.add_literal(migraphx::literal{s, vec_shape}); return info.add_literal(migraphx::literal{s, vec_shape});
......
...@@ -20,8 +20,8 @@ struct parse_strideslice : op_parser<parse_strideslice> ...@@ -20,8 +20,8 @@ struct parse_strideslice : op_parser<parse_strideslice>
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector(); auto ends = args[2]->eval().get<int32_t>().to_vector();
auto l0 = args[0]; auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size(); int num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens(); std::vector<int> axes = l0->get_shape().lens();
std::vector<int64_t> op_starts(starts.begin(), starts.end()); std::vector<int64_t> op_starts(starts.begin(), starts.end());
std::vector<int64_t> op_ends(ends.begin(), ends.end()); std::vector<int64_t> op_ends(ends.begin(), ends.end());
...@@ -45,7 +45,7 @@ struct parse_strideslice : op_parser<parse_strideslice> ...@@ -45,7 +45,7 @@ struct parse_strideslice : op_parser<parse_strideslice>
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask); std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask); std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++) for(int i = 0; i < num_axes; i++)
{ {
if(begin_axes.at(i) == 1) if(begin_axes.at(i) == 1)
{ {
...@@ -62,7 +62,7 @@ struct parse_strideslice : op_parser<parse_strideslice> ...@@ -62,7 +62,7 @@ struct parse_strideslice : op_parser<parse_strideslice>
if(shrink_axis_mask == 0) if(shrink_axis_mask == 0)
return l1; return l1;
for(size_t i = 0; i < num_axes; i++) for(int i = 0; i < num_axes; i++)
{ {
// the LSB corresponds to axis 0 when determining which axes to squeeze // the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1) if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment