Commit edc23800 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the data type for lens and strides from size_t to int in the shape class

parent c7419a9c
...@@ -33,7 +33,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar> ...@@ -33,7 +33,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto input_type = input_shape.type(); auto input_type = input_shape.type();
auto scale_val = info.add_literal(literal{shape{input_type}, {scale}}); auto scale_val = info.add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = info.add_literal(literal{shape{input_type, {bias.size()}}, bias}); auto bias_vals = info.add_literal(literal{shape{input_type, {static_cast<int>(bias.size())}}, bias});
auto scale_tensor = info.add_instruction( auto scale_tensor = info.add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", input_lens}}), scale_val); migraphx::make_op("scalar", {{"scalar_bcst_dims", input_lens}}), scale_val);
......
...@@ -47,9 +47,9 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -47,9 +47,9 @@ struct parse_matmul : op_parser<parse_matmul>
if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend())) if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{ {
auto l0_it = l0_lens.begin() + l0_lens.size() - 2; auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it); std::vector<int> l0_broadcasted_lens(l0_lens.begin(), l0_it);
auto l1_it = l1_lens.begin() + l1_lens.size() - 2; auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it); std::vector<int> l1_broadcasted_lens(l1_lens.begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens; l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end()); l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
......
...@@ -23,7 +23,7 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -23,7 +23,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
dtype = info.attributes.at("dtype").i(); dtype = info.attributes.at("dtype").i();
shape::type_t output_type = get_type(dtype); shape::type_t output_type = get_type(dtype);
size_t sample_size = 1; int sample_size = 1;
if(contains(info.attributes, "sample_size")) if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i(); sample_size = info.attributes.at("sample_size").i();
...@@ -46,7 +46,7 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -46,7 +46,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
gen.seed(info.attributes.at("seed").f()); gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
size_t batch_size = args[0]->get_shape().lens().front(); int batch_size = args[0]->get_shape().lens().front();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
std::vector<float> random_dist(batch_size * sample_size); std::vector<float> random_dist(batch_size * sample_size);
......
...@@ -9,10 +9,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -9,10 +9,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
template <class T> template <class T>
static std::vector<std::size_t> nonzero_indices(const std::vector<T>& data) static std::vector<int> nonzero_indices(const std::vector<T>& data)
{ {
std::vector<std::size_t> indices; std::vector<int> indices;
for(std::size_t i = 0; i < data.size(); ++i) for(int i = 0; i < data.size(); ++i)
{ {
if(!float_equal(data[i], 0)) if(!float_equal(data[i], 0))
indices.push_back(i); indices.push_back(i);
...@@ -37,7 +37,7 @@ struct parse_nonzero : op_parser<parse_nonzero> ...@@ -37,7 +37,7 @@ struct parse_nonzero : op_parser<parse_nonzero>
} }
else else
{ {
std::vector<std::size_t> indices; std::vector<int> indices;
data_arg.visit([&](auto val) { data_arg.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>; using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
std::vector<val_type> vec_data; std::vector<val_type> vec_data;
...@@ -46,13 +46,13 @@ struct parse_nonzero : op_parser<parse_nonzero> ...@@ -46,13 +46,13 @@ struct parse_nonzero : op_parser<parse_nonzero>
}); });
shape in_s = args[0]->get_shape(); shape in_s = args[0]->get_shape();
shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}}; shape out_s{shape::int64_type, {static_cast<int>(in_s.lens().size()), static_cast<int>(indices.size())}};
std::vector<int64_t> out_data(out_s.elements()); std::vector<int64_t> out_data(out_s.elements());
for(std::size_t i = 0; i < indices.size(); ++i) for(int i = 0; i < indices.size(); ++i)
{ {
auto idx = in_s.multi(indices[i]); auto idx = in_s.multi(indices[i]);
for(std::size_t j = 0; j < in_s.lens().size(); ++j) for(int j = 0; j < in_s.lens().size(); ++j)
{ {
out_data[out_s.index({j, i})] = idx[j]; out_data[out_s.index({j, i})] = idx[j];
} }
......
...@@ -20,7 +20,7 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -20,7 +20,7 @@ struct parse_onehot : op_parser<parse_onehot>
{ {
migraphx::argument depth_arg = args[1]->eval(); migraphx::argument depth_arg = args[1]->eval();
check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported"); check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
size_t depth = depth_arg.at<size_t>(); int depth = depth_arg.at<int>();
int64_t axis = -1; int64_t axis = -1;
if(contains(info.attributes, "axis")) if(contains(info.attributes, "axis"))
......
...@@ -32,7 +32,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info, ...@@ -32,7 +32,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
const std::vector<int64_t>& pads, const std::vector<int64_t>& pads,
instruction_ref input) instruction_ref input)
{ {
size_t num_dims = pads.size() / 2; int num_dims = pads.size() / 2;
std::vector<int> ldims(pads.begin(), pads.begin() + num_dims); std::vector<int> ldims(pads.begin(), pads.begin() + num_dims);
std::vector<int> rdims(pads.begin() + num_dims, pads.end()); std::vector<int> rdims(pads.begin() + num_dims, pads.end());
assert(ldims.size() == rdims.size()); assert(ldims.size() == rdims.size());
...@@ -50,7 +50,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info, ...@@ -50,7 +50,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
continue; continue;
// calculate starts and ends for each iteration since shape may change // calculate starts and ends for each iteration since shape may change
std::vector<size_t> dims = input->get_shape().lens(); std::vector<int> dims = input->get_shape().lens();
std::vector<int64_t> starts(axes.size(), 0); std::vector<int64_t> starts(axes.size(), 0);
std::vector<int64_t> ends(dims.begin(), dims.end()); std::vector<int64_t> ends(dims.begin(), dims.end());
std::vector<instruction_ref> slices; std::vector<instruction_ref> slices;
......
...@@ -36,7 +36,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -36,7 +36,7 @@ struct parse_pooling : op_parser<parse_pooling>
if(starts_with(opd.onnx_name, "Global")) if(starts_with(opd.onnx_name, "Global"))
{ {
values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end()); values["lengths"] = std::vector<int>(in_lens.begin() + 2, in_lens.end());
} }
// does not support ceil_mode // does not support ceil_mode
...@@ -86,7 +86,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -86,7 +86,7 @@ struct parse_pooling : op_parser<parse_pooling>
// return paddings could be empty, then setting to 0 for no padding // return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size(info, cal_auto_padding_size(info,
values, values,
values["lengths"].to_vector<std::size_t>(), values["lengths"].to_vector<int>(),
{1, 1}, {1, 1},
in_lens, in_lens,
paddings); paddings);
...@@ -133,7 +133,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -133,7 +133,7 @@ struct parse_pooling : op_parser<parse_pooling>
slice_end.begin(), slice_end.begin(),
[](auto i, auto j) { return i + j; }); [](auto i, auto j) { return i + j; });
} }
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end()); values["padding"] = std::vector<int>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val); check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values); op.from_value(values);
......
...@@ -20,9 +20,9 @@ struct parse_shape : op_parser<parse_shape> ...@@ -20,9 +20,9 @@ struct parse_shape : op_parser<parse_shape>
{ {
if(args.size() != 1) if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand"); MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens(); std::vector<int> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size()); std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()}); migraphx::shape s(migraphx::shape::int64_type, {static_cast<int>(arg_shape.size())});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i); return int64_t(i);
}); });
......
diff --git a/src/include/migraphx/op/capture.hpp b/src/include/migraphx/op/capture.hpp
index f33eab9bb..80ffcbe6b 100644
--- a/src/include/migraphx/op/capture.hpp
+++ b/src/include/migraphx/op/capture.hpp
@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
+#include <migraphx/context.hpp>
#include <cmath>
#include <utility>
@@ -29,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
- argument compute(const shape&, std::vector<argument> args) const
+ argument compute(const shape&, std::vector<argument> args) const { return args.front(); }
+
+ argument compute(context&, const shape&, const std::vector<argument>& args) const
{
if(f)
{
diff --git a/src/include/migraphx/operation.hpp b/src/include/migraphx/operation.hpp
index 922eabd67..56108a871 100644
--- a/src/include/migraphx/operation.hpp
+++ b/src/include/migraphx/operation.hpp
@@ -271,25 +271,25 @@ auto compute_op(rank<3>,
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
- context&,
+ context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(output, inputs))
+ F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
- return x.compute(output, inputs);
+ return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
- context& ctx,
+ context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+ F) -> decltype(x.compute(output, inputs))
{
- return x.compute(auto_any_cast(ctx), output, inputs);
+ return x.compute(output, inputs);
}
template <class T, class F>
diff --git a/tools/include/operation.hpp b/tools/include/operation.hpp
index 0c49edfaf..ef9927cdc 100644
--- a/tools/include/operation.hpp
+++ b/tools/include/operation.hpp
@@ -271,25 +271,25 @@ auto compute_op(rank<3>,
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
- context&,
+ context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(output, inputs))
+ F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
- return x.compute(output, inputs);
+ return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
- context& ctx,
+ context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
- F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+ F) -> decltype(x.compute(output, inputs))
{
- return x.compute(auto_any_cast(ctx), output, inputs);
+ return x.compute(output, inputs);
}
template <class T, class F>
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool reduce_dim(std::vector<shape>& shapes, std::size_t n) bool reduce_dim(std::vector<shape>& shapes, int n)
{ {
std::vector<std::size_t> new_lens; std::vector<int> new_lens;
for(const auto& s : shapes) for(const auto& s : shapes)
{ {
assert(n < s.lens().size()); assert(n < s.lens().size());
...@@ -23,7 +23,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -23,7 +23,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
} }
if(new_lens.size() != shapes.size()) if(new_lens.size() != shapes.size())
return false; return false;
std::size_t i = 0; int i = 0;
for(auto& s : shapes) for(auto& s : shapes)
{ {
auto lens = s.lens(); auto lens = s.lens();
...@@ -37,7 +37,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n) ...@@ -37,7 +37,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return true; return true;
} }
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n) int reduce_dim_all(std::vector<shape>& shapes, int n)
{ {
while(reduce_dim(shapes, n) and n < shapes.size()) while(reduce_dim(shapes, n) and n < shapes.size())
{ {
...@@ -47,16 +47,16 @@ std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n) ...@@ -47,16 +47,16 @@ std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
} }
void reduce_dim_all(std::vector<shape>& shapes) void reduce_dim_all(std::vector<shape>& shapes)
{ {
std::size_t n = 0; int n = 0;
while(n < shapes.front().lens().size() - 1) while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n); n = reduce_dim_all(shapes, n);
} }
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes) std::vector<int> base_lens(const std::vector<shape>& shapes)
{ {
return std::accumulate( return std::accumulate(
shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) { shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
std::vector<std::size_t> result; std::vector<int> result;
const auto* x = &s.lens(); const auto* x = &s.lens();
const auto* y = &lens; const auto* y = &lens;
if(x->size() > y->size()) if(x->size() > y->size())
...@@ -69,12 +69,12 @@ std::vector<std::size_t> base_lens(const std::vector<shape>& shapes) ...@@ -69,12 +69,12 @@ std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
}); });
} }
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens) shape mask_shape(const shape& s, const std::vector<int>& lens)
{ {
assert(s.lens().size() == lens.size()); assert(s.lens().size() == lens.size());
std::vector<std::size_t> rstrides(lens.size()); std::vector<int> rstrides(lens.size());
std::size_t stride = 1; int stride = 1;
for(std::size_t i = lens.size() - 1; i < lens.size(); i--) for(int i = lens.size() - 1; i < lens.size(); i--)
{ {
if(lens[i] == s.lens()[i]) if(lens[i] == s.lens()[i])
{ {
......
...@@ -28,7 +28,7 @@ void rewrite_batchnorm::apply(module& p) const ...@@ -28,7 +28,7 @@ void rewrite_batchnorm::apply(module& p) const
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); })) if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue; continue;
std::vector<std::size_t> lens = ins->inputs()[1]->get_shape().lens(); std::vector<int> lens = ins->inputs()[1]->get_shape().lens();
shape s{ins->get_shape().type(), lens}; shape s{ins->get_shape().type(), lens};
// Get epsilon // Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator()); auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
...@@ -39,8 +39,8 @@ void rewrite_batchnorm::apply(module& p) const ...@@ -39,8 +39,8 @@ void rewrite_batchnorm::apply(module& p) const
visit_all(gamma, bias, mean, variance, a, b)( visit_all(gamma, bias, mean, variance, a, b)(
[&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) { [&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
dfor(a.get_shape().elements())( dfor(a.get_shape().elements())(
[&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); }); [&](int c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
dfor(b.get_shape().elements())([&](std::size_t c) { dfor(b.get_shape().elements())([&](int c) {
b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon)); b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
}); });
}); });
......
...@@ -60,8 +60,8 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const ...@@ -60,8 +60,8 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1]; int hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1]; int batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0); std::vector<float> data(ih_shape.elements(), 0);
...@@ -369,8 +369,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -369,8 +369,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2]; int hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1]; int batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0.0); std::vector<float> data(ih_shape.elements(), 0.0);
...@@ -754,8 +754,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -754,8 +754,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2]; int hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1]; int batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
std::vector<float> ihc_data(ihc_shape.elements(), 0.0); std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
...@@ -1195,7 +1195,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1195,7 +1195,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// specifiy any actv func. If less than 46, use the // specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions // algorithm in parse_lstm to make 6 actv functions
const auto& actv_funcs = lstm_op.actv_funcs; const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size(); int num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::rnn_direction::bidirectional) if(lstm_op.direction == op::rnn_direction::bidirectional)
{ {
switch(num_actv_funcs) switch(num_actv_funcs)
...@@ -1295,7 +1295,7 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l ...@@ -1295,7 +1295,7 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
return is_var_lens; return is_var_lens;
} }
std::size_t int
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
{ {
bool is_var_lens = is_variable_seq_lens(prog, seq_lens); bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
...@@ -1304,7 +1304,7 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ ...@@ -1304,7 +1304,7 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_
if(!is_var_lens and seq_lens != prog.end()) if(!is_var_lens and seq_lens != prog.end())
{ {
auto arg_len = seq_lens->eval(); auto arg_len = seq_lens->eval();
std::vector<std::size_t> vec_lens; std::vector<int> vec_lens;
arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); }); arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
length = vec_lens.empty() ? length : vec_lens[0]; length = vec_lens.empty() ? length : vec_lens[0];
} }
...@@ -1414,7 +1414,7 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog, ...@@ -1414,7 +1414,7 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
{ {
auto s = hs->get_shape(); auto s = hs->get_shape();
auto pad_lens = s.lens(); auto pad_lens = s.lens();
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len); pad_lens[0] = static_cast<int>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens}; shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f); std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end()); auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
......
...@@ -26,14 +26,14 @@ struct shape_impl ...@@ -26,14 +26,14 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<int> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<int> l, std::vector<int> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
...@@ -46,8 +46,8 @@ struct shape_impl ...@@ -46,8 +46,8 @@ struct shape_impl
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
std::vector<std::size_t> m_lens = {}; std::vector<int> m_lens = {};
std::vector<std::size_t> m_strides = {}; std::vector<int> m_strides = {};
std::vector<shape> m_shapes = {}; std::vector<shape> m_shapes = {};
bool m_standard = false; bool m_standard = false;
...@@ -61,10 +61,10 @@ struct shape_impl ...@@ -61,10 +61,10 @@ struct shape_impl
std::partial_sum(m_lens.rbegin(), std::partial_sum(m_lens.rbegin(),
m_lens.rend() - 1, m_lens.rend() - 1,
m_strides.rbegin() + 1, m_strides.rbegin() + 1,
std::multiplies<std::size_t>()); std::multiplies<int>());
} }
std::size_t element_space() const int element_space() const
{ {
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
...@@ -72,19 +72,19 @@ struct shape_impl ...@@ -72,19 +72,19 @@ struct shape_impl
return std::inner_product(m_lens.begin(), return std::inner_product(m_lens.begin(),
m_lens.end(), m_lens.end(),
m_strides.begin(), m_strides.begin(),
std::size_t{0}, int{0},
std::plus<std::size_t>{}, std::plus<int>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) + [](int l, int s) { return (l - 1) * s; }) +
1; 1;
} }
std::size_t elements() const int elements() const
{ {
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
return std::accumulate( return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), int{1}, std::multiplies<int>());
} }
}; };
...@@ -124,11 +124,11 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -124,11 +124,11 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l) shape::shape(type_t t, std::vector<int> l)
: impl(std::make_shared<shape_impl>(t, std::move(l))) : impl(std::make_shared<shape_impl>(t, std::move(l)))
{ {
} }
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape::shape(type_t t, std::vector<int> l, std::vector<int> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s))) : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{ {
} }
...@@ -136,7 +136,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -136,7 +136,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape shape::from_permutation(type_t t, shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l, const std::vector<int>& l,
const std::vector<int64_t>& perm) const std::vector<int64_t>& perm)
{ {
auto new_lens = reorder_dims(l, perm); auto new_lens = reorder_dims(l, perm);
...@@ -146,14 +146,14 @@ shape shape::from_permutation(type_t t, ...@@ -146,14 +146,14 @@ shape shape::from_permutation(type_t t,
} }
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<int>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<int>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const { return impl->elements(); } int shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const int shape::bytes() const
{ {
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
{ {
std::size_t n = 0; int n = 0;
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space(); return n * this->element_space();
} }
...@@ -161,44 +161,44 @@ std::size_t shape::bytes() const ...@@ -161,44 +161,44 @@ std::size_t shape::bytes() const
{ {
return std::accumulate(this->sub_shapes().begin(), return std::accumulate(this->sub_shapes().begin(),
this->sub_shapes().end(), this->sub_shapes().end(),
std::size_t{0}, int{0},
[&](auto x, auto y) { return x + y.bytes(); }); [&](auto x, auto y) { return x + y.bytes(); });
} }
} }
std::size_t shape::type_size() const int shape::type_size() const
{ {
std::size_t n = 0; int n = 0;
if(this->sub_shapes().empty()) if(this->sub_shapes().empty())
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n; return n;
} }
std::size_t shape::index(std::initializer_list<std::size_t> l) const int shape::index(std::initializer_list<int> l) const
{ {
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), int{0});
} }
std::size_t shape::index(const std::vector<std::size_t>& l) const int shape::index(const std::vector<int>& l) const
{ {
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0}); return std::inner_product(l.begin(), l.end(), this->strides().begin(), int{0});
} }
std::size_t shape::index(std::size_t i) const int shape::index(int i) const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
else else
{ {
std::size_t s = 1; int s = 1;
std::size_t result = 0; int result = 0;
for(std::size_t j = 0; j < this->lens().size(); j++) for(int j = 0; j < this->lens().size(); j++)
{ {
const std::size_t k = this->lens().size() - j - 1; const int k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k]; const int stride = this->strides()[k];
const std::size_t len = this->lens()[k]; const int len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s; const int idx = (i % (s * len)) / s;
result += stride * idx; result += stride * idx;
s *= len; s *= len;
} }
...@@ -206,17 +206,17 @@ std::size_t shape::index(std::size_t i) const ...@@ -206,17 +206,17 @@ std::size_t shape::index(std::size_t i) const
} }
} }
std::vector<std::size_t> shape::multi(std::size_t i) const std::vector<int> shape::multi(int i) const
{ {
assert(this->standard()); assert(this->standard());
std::vector<std::size_t> indices(lens().size()); std::vector<int> indices(lens().size());
multi_copy(i, indices.data(), indices.data() + lens().size()); multi_copy(i, indices.data(), indices.data() + lens().size());
return indices; return indices;
} }
void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const void shape::multi_copy(int i, int* start, const int* end) const
{ {
assert(this->standard()); assert(this->standard());
(void)end; (void)end;
...@@ -225,7 +225,7 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -225,7 +225,7 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
strides().end(), strides().end(),
lens().begin(), lens().begin(),
start, start,
[&](std::size_t stride, std::size_t len) { [&](int stride, int len) {
assert(len > 0 and stride > 0); assert(len > 0 and stride > 0);
return (i / stride) % len; return (i / stride) % len;
}); });
...@@ -241,12 +241,12 @@ bool shape::transposed() const ...@@ -241,12 +241,12 @@ bool shape::transposed() const
if(this->broadcasted()) if(this->broadcasted())
{ {
// TODO: Use a filter_iterator instead // TODO: Use a filter_iterator instead
std::vector<std::size_t> s; std::vector<int> s;
s.reserve(this->strides().size()); s.reserve(this->strides().size());
std::copy_if(this->strides().begin(), std::copy_if(this->strides().begin(),
this->strides().end(), this->strides().end(),
std::back_inserter(s), std::back_inserter(s),
[](std::size_t x) { return x != 0; }); [](int x) { return x != 0; });
return not std::is_sorted(s.rbegin(), s.rend()); return not std::is_sorted(s.rbegin(), s.rend());
} }
else else
...@@ -260,8 +260,8 @@ bool shape::broadcasted() const ...@@ -260,8 +260,8 @@ bool shape::broadcasted() const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::accumulate(this->strides().begin(), return std::accumulate(this->strides().begin(),
this->strides().end(), this->strides().end(),
std::size_t{1}, int{1},
std::multiplies<std::size_t>()) == 0; std::multiplies<int>()) == 0;
} }
bool shape::scalar() const bool shape::scalar() const
...@@ -269,7 +269,7 @@ bool shape::scalar() const ...@@ -269,7 +269,7 @@ bool shape::scalar() const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false // if any stride > 0, then accumulate will return false
return this->sub_shapes().empty() and return this->sub_shapes().empty() and
std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0; std::accumulate(this->strides().begin(), this->strides().end(), int(0)) == 0;
} }
bool shape::standard() const { return impl->m_standard; } bool shape::standard() const { return impl->m_standard; }
...@@ -282,19 +282,19 @@ shape shape::normalize_standard() const ...@@ -282,19 +282,19 @@ shape shape::normalize_standard() const
return *this; return *this;
} }
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const shape shape::with_lens(type_t t, const std::vector<int>& l) const
{ {
assert(l.size() == this->lens().size()); assert(l.size() == this->lens().size());
auto perm = find_permutation(*this); auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm); return shape::from_permutation(t, l, perm);
} }
shape shape::with_lens(const std::vector<std::size_t>& l) const shape shape::with_lens(const std::vector<int>& l) const
{ {
return this->with_lens(this->type(), l); return this->with_lens(this->type(), l);
} }
std::size_t shape::element_space() const { return impl->element_space(); } int shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
...@@ -351,8 +351,8 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -351,8 +351,8 @@ void migraphx_from_value(const value& v, shape& s)
else else
{ {
s = shape{shape::parse_type(t), s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(), v.at("lens").to_vector<int>(),
v.at("strides").to_vector<std::size_t>()}; v.at("strides").to_vector<int>()};
} }
} }
......
...@@ -278,10 +278,10 @@ struct find_concat_op ...@@ -278,10 +278,10 @@ struct find_concat_op
} }
template <class Iterator> template <class Iterator>
static std::vector<std::size_t> get_output_lens(Iterator start, Iterator last, std::size_t axis) static std::vector<int> get_output_lens(Iterator start, Iterator last, int axis)
{ {
assert(start != last); assert(start != last);
std::size_t dim = 0; int dim = 0;
for(auto ins : range(start, last)) for(auto ins : range(start, last))
{ {
dim += ins->get_shape().lens().at(axis); dim += ins->get_shape().lens().at(axis);
...@@ -323,7 +323,7 @@ struct find_concat_op ...@@ -323,7 +323,7 @@ struct find_concat_op
} }
std::vector<instruction_ref> concats; std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++) for(int i = 0; i < x->inputs().size(); i++)
{ {
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(start, last, std::back_inserter(inputs), [&](auto j) { std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
...@@ -381,7 +381,7 @@ std::vector<instruction_ref> get_splits(instruction_ref ins) ...@@ -381,7 +381,7 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); }); result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); });
if(it != result.end()) if(it != result.end())
return {}; return {};
for(std::size_t i = 0; i < axes.size(); i++) for(int i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
if(ins->get_shape().lens()[axis] != get_slice(result.back()).ends[i]) if(ins->get_shape().lens()[axis] != get_slice(result.back()).ends[i])
...@@ -626,16 +626,16 @@ struct find_split_concat ...@@ -626,16 +626,16 @@ struct find_split_concat
} }
}; };
bool axis_equal(const std::vector<std::size_t>& x, bool axis_equal(const std::vector<int>& x,
const std::vector<std::size_t>& y, const std::vector<int>& y,
std::size_t axis) int axis)
{ {
return x.size() == y.size() and x.size() > axis and return x.size() == y.size() and x.size() > axis and
std::equal(x.begin(), x.begin() + axis, y.begin()) and std::equal(x.begin(), x.begin() + axis, y.begin()) and
std::equal(x.begin() + axis + 1, x.end(), y.begin() + axis + 1); std::equal(x.begin() + axis + 1, x.end(), y.begin() + axis + 1);
} }
bool axis_shape_equal(const shape& x, const shape& y, std::size_t axis) bool axis_shape_equal(const shape& x, const shape& y, int axis)
{ {
// TODO: Check strides // TODO: Check strides
return axis_equal(x.lens(), y.lens(), axis); return axis_equal(x.lens(), y.lens(), axis);
...@@ -654,7 +654,7 @@ struct find_add_convs ...@@ -654,7 +654,7 @@ struct find_add_convs
return op.stride[0] == op.stride[1]; return op.stride[0] == op.stride[1];
} }
static std::size_t compute_stride_factor(const op::convolution& x, const op::convolution& y) static int compute_stride_factor(const op::convolution& x, const op::convolution& y)
{ {
if(not symmetrical_strides(x)) if(not symmetrical_strides(x))
return 0; return 0;
...@@ -913,7 +913,7 @@ struct find_split_reshape ...@@ -913,7 +913,7 @@ struct find_split_reshape
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto slc_lens = slc->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate( auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>()); slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<int>());
// search the reshape output (standard shape) to decide which axis are // search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size // in its output corresponding to the slc_dim_size
...@@ -942,7 +942,7 @@ struct find_split_reshape ...@@ -942,7 +942,7 @@ struct find_split_reshape
// replace the original reshape with slice // replace the original reshape with slice
int64_t start = 0; int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(int i = 0; i < vec_rsp.size(); ++i)
{ {
p.replace_instruction( p.replace_instruction(
vec_rsp[i], vec_rsp[i],
......
...@@ -174,14 +174,14 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -174,14 +174,14 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
constexpr index_int compute_block_size(index_int n, index_int max_block_size) constexpr index_int compute_block_size(index_int n, index_int max_block_size)
{ {
size_t block_size = 64; int block_size = 64;
while(block_size < max_block_size and block_size < n) while(block_size < max_block_size and block_size < n)
block_size *= 2; block_size *= 2;
return block_size; return block_size;
} }
inline std::vector<index_int> get_reduce_lens(const std::vector<size_t>& input_lens, inline std::vector<index_int> get_reduce_lens(const std::vector<int>& input_lens,
const std::vector<size_t>& output_lens) const std::vector<int>& output_lens)
{ {
std::vector<index_int> reduce_lens; std::vector<index_int> reduce_lens;
std::transform(output_lens.begin(), std::transform(output_lens.begin(),
......
...@@ -16,9 +16,9 @@ static auto make_mat(tensor_view<T> x) ...@@ -16,9 +16,9 @@ static auto make_mat(tensor_view<T> x)
{ {
const auto& s = x.get_shape(); const auto& s = x.get_shape();
// assert(s.lens().size() == 2); // assert(s.lens().size() == 2);
std::size_t n_dims = s.lens().size(); int n_dims = s.lens().size();
std::size_t dim_0 = n_dims - 2; int dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1; int dim_1 = n_dims - 1;
if(s.transposed()) if(s.transposed())
return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]}; return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]}; return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
...@@ -66,9 +66,9 @@ template <class T, class F> ...@@ -66,9 +66,9 @@ template <class T, class F>
void migemm_impl( void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type) tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{ {
std::size_t n_dims = cmat.get_shape().lens().size(); int n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2; int dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1; int dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1]; auto k = amat.get_shape().lens()[dim_1];
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
...@@ -93,7 +93,7 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, ...@@ -93,7 +93,7 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat,
auto lens = amat.get_shape().lens(); auto lens = amat.get_shape().lens();
bool batch_mul = bool batch_mul =
std::accumulate( std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1; lens.rbegin() + 2, lens.rend(), int{1}, std::multiplies<int>()) == 1;
if(batch_mul) if(batch_mul)
{ {
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{}); migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
...@@ -22,7 +22,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -22,7 +22,7 @@ struct parse_conv : op_parser<parse_conv>
op::convolution op; op::convolution op;
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<int> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride); parser.reorder_data(stride);
if(stride.size() != 4) if(stride.size() != 4)
...@@ -34,7 +34,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -34,7 +34,7 @@ struct parse_conv : op_parser<parse_conv>
} }
if(contains(info.attributes, "dilations")) if(contains(info.attributes, "dilations"))
{ {
std::vector<size_t> dilation; std::vector<int> dilation;
copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation)); copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
parser.reorder_data(dilation); parser.reorder_data(dilation);
if(dilation.size() != 4) if(dilation.size() != 4)
...@@ -53,16 +53,16 @@ struct parse_conv : op_parser<parse_conv> ...@@ -53,16 +53,16 @@ struct parse_conv : op_parser<parse_conv>
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens(); std::vector<int> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2]; int weight_h = weight_dims[2];
size_t weight_w = weight_dims[3]; int weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens(); auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size()); std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
op.padding = std::vector<size_t>(pads.begin(), pads.end()); op.padding = std::vector<int>(pads.begin(), pads.end());
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
...@@ -70,7 +70,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -70,7 +70,7 @@ struct parse_conv : op_parser<parse_conv>
} }
else if(pad_mode.find("EXPLICIT") != std::string::npos) else if(pad_mode.find("EXPLICIT") != std::string::npos)
{ {
std::vector<size_t> padding; std::vector<int> padding;
copy(info.attributes.at("explicit_paddings").list().i(), copy(info.attributes.at("explicit_paddings").list().i(),
std::back_inserter(padding)); std::back_inserter(padding));
if(padding.size() != 4) if(padding.size() != 4)
......
...@@ -20,12 +20,12 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv> ...@@ -20,12 +20,12 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
op::convolution op; op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1]; int num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels; op.group = num_channels;
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
std::vector<size_t> stride; std::vector<int> stride;
copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
parser.reorder_data(stride); parser.reorder_data(stride);
if(stride.size() != 4) if(stride.size() != 4)
...@@ -39,7 +39,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv> ...@@ -39,7 +39,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
auto weights = parser.to_kcxy(args[1]); auto weights = parser.to_kcxy(args[1]);
if(contains(info.attributes, "dilations")) if(contains(info.attributes, "dilations"))
{ {
std::vector<size_t> dilation; std::vector<int> dilation;
copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation)); copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
parser.reorder_data(dilation); parser.reorder_data(dilation);
if(dilation.size() != 4) if(dilation.size() != 4)
...@@ -58,9 +58,9 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv> ...@@ -58,9 +58,9 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
if(pad_mode.find("SAME") != std::string::npos) if(pad_mode.find("SAME") != std::string::npos)
{ {
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens(); std::vector<int> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2]; int weight_h = weight_dims[2];
size_t weight_w = weight_dims[3]; int weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens(); auto input_dims = l0->get_shape().lens();
std::vector<int64_t> pads(input_dims.size()); std::vector<int64_t> pads(input_dims.size());
......
...@@ -17,9 +17,9 @@ struct parse_expanddims : op_parser<parse_expanddims> ...@@ -17,9 +17,9 @@ struct parse_expanddims : op_parser<parse_expanddims>
const tf_parser::node_info& info, const tf_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::vector<size_t> input_dims = args[0]->get_shape().lens(); std::vector<int> input_dims = args[0]->get_shape().lens();
std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end()); std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
size_t num_dims = input_dims.size(); int num_dims = input_dims.size();
int32_t dim = args[1]->eval().at<int32_t>(); int32_t dim = args[1]->eval().at<int32_t>();
if(dim < 0) if(dim < 0)
......
...@@ -17,7 +17,7 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -17,7 +17,7 @@ struct parse_onehot : op_parser<parse_onehot>
tf_parser::node_info info, tf_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>()); int depth = static_cast<int>(args[1]->eval().at<int32_t>());
int64_t axis = -1; int64_t axis = -1;
float on_value = args[2]->eval().at<float>(); float on_value = args[2]->eval().at<float>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment