Commit edc23800 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the data type for lens and strides from size_t to int in the shape class

parent c7419a9c
......@@ -23,7 +23,7 @@ struct parse_pooling : op_parser<parse_pooling>
if(contains(info.attributes, "strides"))
{
std::vector<size_t> stride;
std::vector<int> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride);
if(stride.size() != 4)
......@@ -35,7 +35,7 @@ struct parse_pooling : op_parser<parse_pooling>
}
if(contains(info.attributes, "ksize"))
{
std::vector<size_t> ksize;
std::vector<int> ksize;
copy(info.attributes.at("ksize").list().i(), std::back_inserter(ksize));
parser.reorder_data(ksize);
if(ksize.size() != 4)
......@@ -57,7 +57,7 @@ struct parse_pooling : op_parser<parse_pooling>
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
op.padding = std::vector<size_t>(pads.begin(), pads.end());
op.padding = std::vector<int>(pads.begin(), pads.end());
}
}
return info.add_instruction(op, l0);
......
......@@ -19,9 +19,9 @@ struct parse_shape : op_parser<parse_shape>
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int> arg_shape = args[0]->get_shape().lens();
std::vector<int32_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
migraphx::shape s(migraphx::shape::int32_type, {static_cast<int>(arg_shape.size())});
std::transform(
arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
return info.add_literal(migraphx::literal{s, vec_shape});
......
......@@ -20,8 +20,8 @@ struct parse_strideslice : op_parser<parse_strideslice>
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
int num_axes = l0->get_shape().lens().size();
std::vector<int> axes = l0->get_shape().lens();
std::vector<int64_t> op_starts(starts.begin(), starts.end());
std::vector<int64_t> op_ends(ends.begin(), ends.end());
......@@ -45,7 +45,7 @@ struct parse_strideslice : op_parser<parse_strideslice>
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
for(int i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
......@@ -62,7 +62,7 @@ struct parse_strideslice : op_parser<parse_strideslice>
if(shrink_axis_mask == 0)
return l1;
for(size_t i = 0; i < num_axes; i++)
for(int i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment