Commit 94e3a2e4 authored by Shucai Xiao
Browse files

change size_t to int

parent 26bd92d8
...@@ -45,7 +45,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name, ...@@ -45,7 +45,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
} }
else else
{ {
std::size_t n_dim = args.front()->get_shape().lens().size(); int n_dim = args.front()->get_shape().lens().size();
axes.resize(n_dim); axes.resize(n_dim);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
} }
......
...@@ -11,26 +11,26 @@ namespace onnx { ...@@ -11,26 +11,26 @@ namespace onnx {
const auto& get_nearest_op(const std::string& mode) const auto& get_nearest_op(const std::string& mode)
{ {
using nearest_op = std::function<std::size_t(std::size_t, double)>; using nearest_op = std::function<int(int, double)>;
static std::unordered_map<std::string, nearest_op> const nearest_ops = { static std::unordered_map<std::string, nearest_op> const nearest_ops = {
{"round_prefer_floor", {"round_prefer_floor",
[=](std::size_t d_in, double val) { [=](int d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val)); val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::ceil((val - 0.5))); return static_cast<int>(std::ceil((val - 0.5)));
}}, }},
{"round_prefer_ceil", {"round_prefer_ceil",
[=](std::size_t d_in, double val) { [=](int d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val)); val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::round((val))); return static_cast<int>(std::round((val)));
}}, }},
{"floor", {"floor",
[=](std::size_t d_in, double val) { [=](int d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val)); val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::floor((val))); return static_cast<int>(std::floor((val)));
}}, }},
{"ceil", [=](std::size_t d_in, double val) { {"ceil", [=](int d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val)); val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::ceil((val))); return static_cast<int>(std::ceil((val)));
}}}; }}};
if(!contains(nearest_ops, mode)) if(!contains(nearest_ops, mode))
...@@ -43,23 +43,23 @@ const auto& get_nearest_op(const std::string& mode) ...@@ -43,23 +43,23 @@ const auto& get_nearest_op(const std::string& mode)
const auto& get_original_idx_op(const std::string& mode) const auto& get_original_idx_op(const std::string& mode)
{ {
using original_idx_op = std::function<double(std::size_t, std::size_t, std::size_t, double)>; using original_idx_op = std::function<double(int, int, int, double)>;
static std::unordered_map<std::string, original_idx_op> const idx_ops = { static std::unordered_map<std::string, original_idx_op> const idx_ops = {
{"half_pixel", {"half_pixel",
[=](std::size_t, std::size_t, std::size_t idx, double scale) { [=](int, int, int idx, double scale) {
return (idx + 0.5) / scale - 0.5; return (idx + 0.5) / scale - 0.5;
}}, }},
{"pytorch_half_pixel", {"pytorch_half_pixel",
[=](std::size_t, std::size_t l_out, std::size_t idx, double scale) { [=](int, int l_out, int idx, double scale) {
return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0; return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0;
}}, }},
{"align_corners", {"align_corners",
[=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) { [=](int l_in, int l_out, int idx, double) {
return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0)); return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0));
}}, }},
{"asymmetric", {"asymmetric",
[=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }}, [=](int, int, int idx, double scale) { return idx / scale; }},
{"tf_half_pixel_for_nn", [=](std::size_t, std::size_t, std::size_t idx, double scale) { {"tf_half_pixel_for_nn", [=](int, int, int idx, double scale) {
return (idx + 0.5) / scale; return (idx + 0.5) / scale;
}}}; }}};
...@@ -72,9 +72,9 @@ const auto& get_original_idx_op(const std::string& mode) ...@@ -72,9 +72,9 @@ const auto& get_original_idx_op(const std::string& mode)
} }
static std::vector<int> static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind, calc_neighbor_points(const std::vector<std::vector<std::vector<int>>>& vvv_ind,
int i_dim, int i_dim,
const std::vector<std::vector<std::size_t>>& vec_dims, const std::vector<std::vector<int>>& vec_dims,
const shape& in_s) const shape& in_s)
{ {
if(i_dim == vvv_ind.size()) if(i_dim == vvv_ind.size())
...@@ -90,8 +90,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -90,8 +90,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
const auto& vv_ind = vvv_ind[i_dim]; const auto& vv_ind = vvv_ind[i_dim];
const auto& vv_lo = vv_ind.at(0); const auto& vv_lo = vv_ind.at(0);
std::vector<std::vector<std::size_t>> vec_dims1; std::vector<std::vector<int>> vec_dims1;
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(int start = 0; start < vec_dims.size(); start += vv_lo.size())
{ {
std::transform(vv_lo.begin(), std::transform(vv_lo.begin(),
vv_lo.end(), vv_lo.end(),
...@@ -104,7 +104,7 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -104,7 +104,7 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
} }
const auto& vv_hi = vv_ind.at(1); const auto& vv_hi = vv_ind.at(1);
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(int start = 0; start < vec_dims.size(); start += vv_lo.size())
{ {
std::transform(vv_hi.begin(), std::transform(vv_hi.begin(),
vv_hi.end(), vv_hi.end(),
...@@ -191,7 +191,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -191,7 +191,7 @@ struct parse_resize : op_parser<parse_resize>
auto in_lens = in_s.lens(); auto in_lens = in_s.lens();
// output shape is explicitly specified // output shape is explicitly specified
std::vector<std::size_t> out_lens(in_lens.size()); std::vector<int> out_lens(in_lens.size());
// scale // scale
std::vector<double> vec_scale; std::vector<double> vec_scale;
...@@ -256,14 +256,14 @@ struct parse_resize : op_parser<parse_resize> ...@@ -256,14 +256,14 @@ struct parse_resize : op_parser<parse_resize>
vec_scale.begin(), vec_scale.begin(),
out_lens.begin(), out_lens.begin(),
[&](auto idx, auto scale) { [&](auto idx, auto scale) {
return static_cast<std::size_t>(idx * scale); return static_cast<int>(idx * scale);
}); });
} }
} }
} }
shape out_s{in_s.type(), out_lens}; shape out_s{in_s.type(), out_lens};
std::size_t out_elements = out_s.elements(); int out_elements = out_s.elements();
auto idx_op = get_original_idx_op(coord_trans_mode); auto idx_op = get_original_idx_op(coord_trans_mode);
// reshape input to one-dimension // reshape input to one-dimension
...@@ -299,9 +299,9 @@ struct parse_resize : op_parser<parse_resize> ...@@ -299,9 +299,9 @@ struct parse_resize : op_parser<parse_resize>
auto nearest_ceil = get_nearest_op("ceil"); auto nearest_ceil = get_nearest_op("ceil");
// get the number of dimensions // get the number of dimensions
std::size_t n_dim = out_lens.size(); int n_dim = out_lens.size();
std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements)); std::vector<std::vector<int>> vv_ind(2, std::vector<int>(out_elements));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind); std::vector<std::vector<std::vector<int>>> vvv_ind(n_dim, vv_ind);
std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements)); std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](auto idx) {
...@@ -316,22 +316,22 @@ struct parse_resize : op_parser<parse_resize> ...@@ -316,22 +316,22 @@ struct parse_resize : op_parser<parse_resize>
} }
}); });
std::vector<std::vector<std::size_t>> vec_dims(out_elements); std::vector<std::vector<int>> vec_dims(out_elements);
auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s); auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s);
auto ind_lens = out_lens; auto ind_lens = out_lens;
ind_lens[0] *= (std::size_t{1} << n_dim); ind_lens[0] *= (int{1} << n_dim);
shape ind_s{shape::int32_type, ind_lens}; shape ind_s{shape::int32_type, ind_lens};
auto ins_ind = info.add_literal(literal(ind_s, ind)); auto ins_ind = info.add_literal(literal(ind_s, ind));
auto data = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind); auto data = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
auto dim_lens = out_lens; auto dim_lens = out_lens;
dim_lens[0] *= (std::size_t{1} << (n_dim - 1)); dim_lens[0] *= (int{1} << (n_dim - 1));
for(std::size_t i = 0; i < n_dim; ++i) for(int i = 0; i < n_dim; ++i)
{ {
shape dim_s{shape::float_type, dim_lens}; shape dim_s{shape::float_type, dim_lens};
const auto& dim_delta = delta[n_dim - i - 1]; const auto& dim_delta = delta[n_dim - i - 1];
std::vector<float> delta_data; std::vector<float> delta_data;
for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j) for(int j = 0; j < dim_lens[0] / out_lens[0]; ++j)
{ {
delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end()); delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end());
} }
......
...@@ -20,11 +20,11 @@ struct parse_rnn : op_parser<parse_rnn> ...@@ -20,11 +20,11 @@ struct parse_rnn : op_parser<parse_rnn>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
migraphx::shape input_shape = args[0]->get_shape(); migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1]; int hidden_size = args[1]->get_shape().lens()[1];
if(contains(info.attributes, "hidden_size")) if(contains(info.attributes, "hidden_size"))
{ {
std::size_t hidden_size_att = int hidden_size_att =
parser.parse_value(info.attributes.at("hidden_size")).at<int>(); parser.parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att) if(hidden_size != hidden_size_att)
{ {
......
...@@ -32,7 +32,7 @@ struct parse_selu : op_parser<parse_selu> ...@@ -32,7 +32,7 @@ struct parse_selu : op_parser<parse_selu>
auto l_alpha = info.add_literal({{type, {1}}, {alpha}}); auto l_alpha = info.add_literal({{type, {1}}, {alpha}});
auto l_gamma = info.add_literal({{type, {1}}, {gamma / 2.0f}}); auto l_gamma = info.add_literal({{type, {1}}, {gamma / 2.0f}});
if(lens != std::vector<std::size_t>{1}) if(lens != std::vector<int>{1})
{ {
l_alpha = l_alpha =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
......
...@@ -37,10 +37,10 @@ void memory_coloring_impl::run() ...@@ -37,10 +37,10 @@ void memory_coloring_impl::run()
bool memory_coloring_impl::allocate(interval_ptr interval) bool memory_coloring_impl::allocate(interval_ptr interval)
{ {
shape s = interval->result; shape s = interval->result;
std::size_t size = s.bytes(); int size = s.bytes();
if(size == 0) if(size == 0)
return false; return false;
std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements())); int element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment; live_range& segment = interval->segment;
int vn = segment.vn; int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue; std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
...@@ -72,11 +72,11 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -72,11 +72,11 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
} }
} }
std::size_t offset = 0; int offset = 0;
while(!conflict_queue.empty()) while(!conflict_queue.empty())
{ {
live_range* range = conflict_queue.top(); live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset; int iter_offset = range->offset;
if(offset > iter_offset) if(offset > iter_offset)
{ {
offset = std::max(offset, iter_offset + range->size); offset = std::max(offset, iter_offset + range->size);
...@@ -105,7 +105,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -105,7 +105,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
void memory_coloring_impl::build() void memory_coloring_impl::build()
{ {
std::size_t num_of_instrs = p_mod->size(); int num_of_instrs = p_mod->size();
if(num_of_instrs == 0) if(num_of_instrs == 0)
return; return;
...@@ -199,7 +199,7 @@ void memory_coloring_impl::build() ...@@ -199,7 +199,7 @@ void memory_coloring_impl::build()
void memory_coloring_impl::rewrite() void memory_coloring_impl::rewrite()
{ {
std::vector<std::size_t> dims; std::vector<int> dims;
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float)); dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims}; shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_mod->add_parameter("scratch", s); instruction_ref scratch_param = p_mod->add_parameter("scratch", s);
...@@ -215,7 +215,7 @@ void memory_coloring_impl::rewrite() ...@@ -215,7 +215,7 @@ void memory_coloring_impl::rewrite()
if(!unify_literals && interval->is_literal) if(!unify_literals && interval->is_literal)
continue; continue;
std::size_t offset = 0; int offset = 0;
if(interval->get_offset() != invalid_offset) if(interval->get_offset() != invalid_offset)
{ {
offset = interval->get_offset(); offset = interval->get_offset();
......
...@@ -22,15 +22,15 @@ ...@@ -22,15 +22,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static const std::size_t invalid_offset = std::numeric_limits<std::size_t>::max(); static const int invalid_offset = std::numeric_limits<int>::max();
struct live_range struct live_range
{ {
std::size_t begin; // begin point in the instruction stream. int begin; // begin point in the instruction stream.
std::size_t end; // end point in the instruction stream. int end; // end point in the instruction stream.
std::size_t offset; // offset to base pointer of allocated memory trunk. int offset; // offset to base pointer of allocated memory trunk.
std::size_t vn; // value number that identifies this live_range. int vn; // value number that identifies this live_range.
std::size_t size; // size of required memory in bytes int size; // size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(); void dump();
#endif #endif
...@@ -42,9 +42,9 @@ struct live_interval ...@@ -42,9 +42,9 @@ struct live_interval
{ {
} }
void add_use(std::size_t use) { use_points.push_front(use); } void add_use(int use) { use_points.push_front(use); }
std::size_t get_begin() const { return segment.begin; } int get_begin() const { return segment.begin; }
std::size_t get_end() const { return segment.end; } int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; } long long get_offset() const { return segment.offset; }
#ifdef MIGRAPHX_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
...@@ -52,9 +52,9 @@ struct live_interval ...@@ -52,9 +52,9 @@ struct live_interval
#endif #endif
live_range segment; live_range segment;
std::size_t id = invalid_offset; int id = invalid_offset;
std::list<std::size_t> use_points{}; std::list<int> use_points{};
std::size_t def_point = invalid_offset; int def_point = invalid_offset;
shape result{}; shape result{};
bool is_literal = false; bool is_literal = false;
bool is_live_on_entry = false; bool is_live_on_entry = false;
...@@ -152,7 +152,7 @@ struct memory_coloring_impl ...@@ -152,7 +152,7 @@ struct memory_coloring_impl
int num_of_lives = 0; int num_of_lives = 0;
int max_value_number = -1; int max_value_number = -1;
std::size_t required_bytes = 0; int required_bytes = 0;
// The earliest program point where an live interval ends. // The earliest program point where an live interval ends.
int earliest_end_point = -1; int earliest_end_point = -1;
// The latest program point where an live interval ends. // The latest program point where an live interval ends.
......
...@@ -32,7 +32,7 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes) ...@@ -32,7 +32,7 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
{ {
if(shapes.empty()) if(shapes.empty())
return {}; return {};
std::map<std::vector<int64_t>, std::size_t> count; std::map<std::vector<int64_t>, int> count;
for(auto&& s : shapes) for(auto&& s : shapes)
{ {
if(s.broadcasted()) if(s.broadcasted())
......
...@@ -120,7 +120,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const ...@@ -120,7 +120,7 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
return mm->get_parameter_shapes(); return mm->get_parameter_shapes();
} }
std::size_t program::size() const { return impl->modules.size(); } int program::size() const { return impl->modules.size(); }
std::vector<shape> program::get_output_shapes() const std::vector<shape> program::get_output_shapes() const
{ {
...@@ -565,7 +565,7 @@ void program::from_value(const value& v) ...@@ -565,7 +565,7 @@ void program::from_value(const value& v)
double common_average(const std::vector<double>& v) double common_average(const std::vector<double>& v)
{ {
std::size_t n = v.size() / 4; int n = v.size() / 4;
double total = std::accumulate(v.begin() + n, v.end() - n, 0.0); double total = std::accumulate(v.begin() + n, v.end() - n, 0.0);
return total / std::distance(v.begin() + n, v.end() - n); return total / std::distance(v.begin() + n, v.end() - n);
} }
...@@ -597,9 +597,9 @@ void program::mark(const parameter_map& params, marker&& m) ...@@ -597,9 +597,9 @@ void program::mark(const parameter_map& params, marker&& m)
} }
void program::perf_report(std::ostream& os, void program::perf_report(std::ostream& os,
std::size_t n, int n,
parameter_map params, parameter_map params,
std::size_t batch) const int batch) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
// Run once by itself // Run once by itself
...@@ -608,7 +608,7 @@ void program::perf_report(std::ostream& os, ...@@ -608,7 +608,7 @@ void program::perf_report(std::ostream& os,
// Run and time entire program // Run and time entire program
std::vector<double> total_vec; std::vector<double> total_vec;
total_vec.reserve(n); total_vec.reserve(n);
for(std::size_t i = 0; i < n; i++) for(int i = 0; i < n; i++)
{ {
total_vec.push_back(time<milliseconds>([&] { total_vec.push_back(time<milliseconds>([&] {
eval(params); eval(params);
...@@ -624,7 +624,7 @@ void program::perf_report(std::ostream& os, ...@@ -624,7 +624,7 @@ void program::perf_report(std::ostream& os,
})); }));
// Run and time each instruction // Run and time each instruction
for(std::size_t i = 0; i < n; i++) for(int i = 0; i < n; i++)
{ {
generic_eval(*this, ctx, params, always([&](auto ins, auto f) { generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
argument result; argument result;
...@@ -640,7 +640,7 @@ void program::perf_report(std::ostream& os, ...@@ -640,7 +640,7 @@ void program::perf_report(std::ostream& os,
// Run and time implicit overhead // Run and time implicit overhead
std::vector<double> overhead_vec; std::vector<double> overhead_vec;
overhead_vec.reserve(n); overhead_vec.reserve(n);
for(std::size_t i = 0; i < n; i++) for(int i = 0; i < n; i++)
{ {
overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); })); overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); }));
} }
......
...@@ -171,7 +171,7 @@ py::buffer_info to_buffer_info(T& x) ...@@ -171,7 +171,7 @@ py::buffer_info to_buffer_info(T& x)
migraphx::shape to_shape(const py::buffer_info& info) migraphx::shape to_shape(const py::buffer_info& info)
{ {
migraphx::shape::type_t t; migraphx::shape::type_t t;
std::size_t n = 0; int n = 0;
visit_types([&](auto as) { visit_types([&](auto as) {
if(info.format == py::format_descriptor<decltype(as())>::format() or if(info.format == py::format_descriptor<decltype(as())>::format() or
(info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or (info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or
...@@ -193,7 +193,7 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -193,7 +193,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
} }
auto strides = info.strides; auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t { std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> int {
return n > 0 ? i / n : 0; return n > 0 ? i / n : 0;
}); });
...@@ -237,7 +237,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -237,7 +237,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("get_shape", &migraphx::argument::get_shape) .def("get_shape", &migraphx::argument::get_shape)
.def("tolist", .def("tolist",
[](migraphx::argument& x) { [](migraphx::argument& x) {
py::list l{x.get_shape().elements()}; py::list l{static_cast<std::size_t>(x.get_shape().elements())};
visit(x, [&](auto data) { l = py::cast(data.to_vector()); }); visit(x, [&](auto data) { l = py::cast(data.to_vector()); });
return l; return l;
}) })
...@@ -306,8 +306,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -306,8 +306,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
m.def("parse_tf", m.def("parse_tf",
[](const std::string& filename, [](const std::string& filename,
bool is_nhwc, bool is_nhwc,
unsigned int batch_size, int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<int>> map_input_dims,
std::vector<std::string> output_names) { std::vector<std::string> output_names) {
return migraphx::parse_tf( return migraphx::parse_tf(
filename, filename,
...@@ -317,13 +317,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -317,13 +317,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true, py::arg("is_nhwc") = true,
py::arg("batch_size") = 1, py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<int>>(),
py::arg("output_names") = std::vector<std::string>()); py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx", m.def("parse_onnx",
[](const std::string& filename, [](const std::string& filename,
unsigned int default_dim_value, unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<int>> map_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error,
int64_t max_loop_iterations) { int64_t max_loop_iterations) {
...@@ -338,7 +338,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -338,7 +338,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"Parse onnx file", "Parse onnx file",
py::arg("filename"), py::arg("filename"),
py::arg("default_dim_value") = 1, py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<int>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10); py::arg("max_loop_iterations") = 10);
...@@ -346,7 +346,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -346,7 +346,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
m.def("parse_onnx_buffer", m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer, [](const std::string& onnx_buffer,
unsigned int default_dim_value, unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<int>> map_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error) { bool print_program_on_error) {
migraphx::onnx_options options; migraphx::onnx_options options;
...@@ -359,7 +359,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -359,7 +359,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"Parse onnx file", "Parse onnx file",
py::arg("filename"), py::arg("filename"),
py::arg("default_dim_value") = 1, py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<int>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false); py::arg("print_program_on_error") = false);
......
...@@ -57,7 +57,7 @@ void quantize_int8(program& prog, ...@@ -57,7 +57,7 @@ void quantize_int8(program& prog,
std::make_shared<std::vector<std::pair<float, float>>>(); std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>(); std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index, auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](int ins_index,
std::vector<argument> args) { std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f}; std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
...@@ -83,7 +83,7 @@ void quantize_int8(program& prog, ...@@ -83,7 +83,7 @@ void quantize_int8(program& prog,
}; };
// pass to add capture argument op // pass to add capture argument op
std::size_t param_num = 0; int param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}}); run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f)); int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f); max_abs_vals->resize(param_num, 0.0f);
...@@ -115,7 +115,7 @@ void quantize_int8(program& prog, ...@@ -115,7 +115,7 @@ void quantize_int8(program& prog,
// print the quantization parameters in only the main module // print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{})) if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{ {
for(std::size_t i = 0; i < int8_quant_params->size(); ++i) for(int i = 0; i < int8_quant_params->size(); ++i)
{ {
auto param = int8_quant_params->at(i); auto param = int8_quant_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first std::cout << "ins_index = " << i << ", scale = " << param.first
......
...@@ -38,7 +38,7 @@ void quantize_int8_pass::apply(module& m) const // NOLINT ...@@ -38,7 +38,7 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto op_val = ins->get_operator().to_value(); auto op_val = ins->get_operator().to_value();
assert(op_val.contains("ins_index")); assert(op_val.contains("ins_index"));
auto param_index = op_val.at("ins_index").to<std::size_t>(); auto param_index = op_val.at("ins_index").to<int>();
auto param = quant_params[param_index]; auto param = quant_params[param_index];
auto input = ins->inputs().front(); auto input = ins->inputs().front();
......
...@@ -579,7 +579,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -579,7 +579,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// initial states // initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
size_t bs = ih->get_shape().lens()[1]; int bs = ih->get_shape().lens()[1];
// bias // bias
instruction_ref bwb{}; instruction_ref bwb{};
...@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction( bwb = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<int>(3 * hs)}}}),
wb); wb);
auto rb_zr = prog.insert_instruction( auto rb_zr = prog.insert_instruction(
...@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
sbias); sbias);
brb_zr = prog.insert_instruction( brb_zr = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<int>(2 * hs)}}}),
rb_zr); rb_zr);
brb_h = prog.insert_instruction( brb_h = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<int>(hs)}}}),
rb_h); rb_h);
} }
...@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
wrb = prog.insert_instruction( wrb = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<int>(hs)}}}),
ub_wrb); ub_wrb);
} }
......
...@@ -37,19 +37,19 @@ auto get_outputs() ...@@ -37,19 +37,19 @@ auto get_outputs()
struct stream_info struct stream_info
{ {
std::unordered_map<instruction_ref, std::size_t> ins2stream; std::unordered_map<instruction_ref, int> ins2stream;
std::unordered_map<instruction_ref, std::size_t> weights; std::unordered_map<instruction_ref, int> weights;
std::unordered_map<instruction_ref, std::size_t> iweights; std::unordered_map<instruction_ref, int> iweights;
ins_dep_map mod_implicit_deps; ins_dep_map mod_implicit_deps;
void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); } void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); }
void accumulate_weights(instruction_ref last, const schedule_model& model) void accumulate_weights(instruction_ref last, const schedule_model& model)
{ {
fix<std::size_t>([&](auto self, auto ins) -> std::size_t { fix<int>([&](auto self, auto ins) -> int {
if(not contains(weights, ins)) if(not contains(weights, ins))
{ {
std::size_t weight = 0; int weight = 0;
auto&& op = ins->get_operator(); auto&& op = ins->get_operator();
if(not is_context_free(op) and op.name()[0] != '@') if(not is_context_free(op) and op.name()[0] != '@')
weight = model.weight(op); weight = model.weight(op);
...@@ -65,7 +65,7 @@ struct stream_info ...@@ -65,7 +65,7 @@ struct stream_info
} }
weights[ins] = std::accumulate( weights[ins] = std::accumulate(
inputs.begin(), inputs.end(), weight, [&](std::size_t w, instruction_ref i) { inputs.begin(), inputs.end(), weight, [&](int w, instruction_ref i) {
return w + self(i); return w + self(i);
}); });
} }
...@@ -91,13 +91,13 @@ struct stream_info ...@@ -91,13 +91,13 @@ struct stream_info
return args.end(); return args.end();
} }
const std::size_t min_partition_threshold = 2; const int min_partition_threshold = 2;
sort_args_by_weight(args, std::greater<>{}); sort_args_by_weight(args, std::greater<>{});
auto it = std::lower_bound(std::next(args.begin()), auto it = std::lower_bound(std::next(args.begin()),
args.end(), args.end(),
min_partition_threshold, min_partition_threshold,
[&](auto i, std::size_t w) { return this->weights[i] > w; }); [&](auto i, int w) { return this->weights[i] > w; });
assert(it == args.end() or this->weights[*it] <= min_partition_threshold); assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
assert(it == args.end() or std::prev(it) == args.begin() or assert(it == args.end() or std::prev(it) == args.begin() or
this->weights[*std::prev(it)] > min_partition_threshold); this->weights[*std::prev(it)] > min_partition_threshold);
...@@ -106,17 +106,17 @@ struct stream_info ...@@ -106,17 +106,17 @@ struct stream_info
struct partition struct partition
{ {
std::size_t weight = 0; int weight = 0;
std::vector<instruction_ref> instructions{}; std::vector<instruction_ref> instructions{};
void add(instruction_ref ins, std::size_t w) void add(instruction_ref ins, int w)
{ {
weight += w; weight += w;
instructions.push_back(ins); instructions.push_back(ins);
} }
}; };
std::size_t assign_streams(module& p, std::size_t n) int assign_streams(module& p, int n)
{ {
assert(n > 0); assert(n > 0);
partition critical; partition critical;
...@@ -166,7 +166,7 @@ struct stream_info ...@@ -166,7 +166,7 @@ struct stream_info
} }
else else
{ {
std::vector<std::size_t> streams(n - 1); std::vector<int> streams(n - 1);
// Assign streams for the other partitions // Assign streams for the other partitions
for(auto&& ins_part : partitions) for(auto&& ins_part : partitions)
{ {
...@@ -187,7 +187,7 @@ struct stream_info ...@@ -187,7 +187,7 @@ struct stream_info
} }
} }
using weight_ins = std::pair<std::size_t, instruction_ref>; using weight_ins = std::pair<int, instruction_ref>;
struct compare_weight_ins struct compare_weight_ins
{ {
bool operator()(const weight_ins& x, const weight_ins& y) const bool operator()(const weight_ins& x, const weight_ins& y) const
...@@ -197,10 +197,10 @@ struct stream_info ...@@ -197,10 +197,10 @@ struct stream_info
} }
}; };
void sort(module& p, std::size_t) void sort(module& p, int)
{ {
std::set<weight_ins, compare_weight_ins> children; std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited; std::unordered_map<instruction_ref, int> visited;
auto last = std::prev(p.end()); auto last = std::prev(p.end());
auto mw = this->weights.at(last); auto mw = this->weights.at(last);
auto nw = mw / (p.size() + 1); auto nw = mw / (p.size() + 1);
...@@ -253,25 +253,25 @@ struct stream_info ...@@ -253,25 +253,25 @@ struct stream_info
} }
} }
void set_stream(const partition& p, std::size_t n) void set_stream(const partition& p, int n)
{ {
for(auto ins : p.instructions) for(auto ins : p.instructions)
if(iweights[ins] > 0) if(iweights[ins] > 0)
set_stream(ins, n); set_stream(ins, n);
} }
void set_stream(instruction_ref ins, std::size_t n) void set_stream(instruction_ref ins, int n)
{ {
assert(iweights[ins] > 0); assert(iweights[ins] > 0);
ins2stream[ins] = n; ins2stream[ins] = n;
} }
std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); } int get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); } bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); }
template <class F> template <class F>
bool different(F f, std::size_t stream) const bool different(F f, int stream) const
{ {
bool result = false; bool result = false;
f([&](auto s) { f([&](auto s) {
...@@ -313,11 +313,11 @@ struct stream_info ...@@ -313,11 +313,11 @@ struct stream_info
}; };
} }
std::unordered_set<std::size_t> get_streams(instruction_ref ins) const std::unordered_set<int> get_streams(instruction_ref ins) const
{ {
if(has_stream(ins)) if(has_stream(ins))
return {get_stream(ins)}; return {get_stream(ins)};
std::unordered_set<std::size_t> result; std::unordered_set<int> result;
get_streams_from(ins, get_inputs())([&](auto s) { get_streams_from(ins, get_inputs())([&](auto s) {
result.insert(s); result.insert(s);
return true; return true;
...@@ -340,7 +340,7 @@ struct stream_info ...@@ -340,7 +340,7 @@ struct stream_info
std::vector<instruction_ref> get_recorded_instructions(instruction_ref start) std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
std::unordered_map<std::size_t, instruction_ref> m; std::unordered_map<int, instruction_ref> m;
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
for(auto i : ins->inputs()) for(auto i : ins->inputs())
{ {
...@@ -424,8 +424,8 @@ struct stream_info ...@@ -424,8 +424,8 @@ struct stream_info
auto concur_ins = this->find_concurrent_instructions(p); auto concur_ins = this->find_concurrent_instructions(p);
// Compute an index for each instruction // Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index; std::unordered_map<instruction_ref, int> ins2index;
std::size_t index_total = 0; int index_total = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
ins2index[ins] = index_total++; ins2index[ins] = index_total++;
...@@ -544,10 +544,10 @@ void schedule::apply(module& p) const ...@@ -544,10 +544,10 @@ void schedule::apply(module& p) const
return; return;
// Schedule instructions // Schedule instructions
std::size_t wait_id = 0; int wait_id = 0;
std::unordered_map<instruction_ref, std::size_t> ins2wait; std::unordered_map<instruction_ref, int> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for; std::unordered_map<int, std::unordered_set<int>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited; std::unordered_map<instruction_ref, std::unordered_set<int>> ins2waited;
ins2wait.reserve(p.size()); ins2wait.reserve(p.size());
ins2waited.reserve(p.size()); ins2waited.reserve(p.size());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
......
...@@ -180,13 +180,13 @@ struct find_nested_slice ...@@ -180,13 +180,13 @@ struct find_nested_slice
{ {
auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); } auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }
using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>; using axes_map = std::map<int, std::pair<int, int>>;
static axes_map get_axes(instruction_ref ins) static axes_map get_axes(instruction_ref ins)
{ {
axes_map result; axes_map result;
auto op = any_cast<op::slice>(ins->get_operator()); auto op = any_cast<op::slice>(ins->get_operator());
for(std::size_t i = 0; i < op.axes.size(); i++) for(int i = 0; i < op.axes.size(); i++)
{ {
result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]); result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
} }
...@@ -297,7 +297,7 @@ struct find_nested_concat ...@@ -297,7 +297,7 @@ struct find_nested_concat
return match::name("concat")(match::any_of[match::inputs()](match::name("concat"))); return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
} }
static std::size_t get_axis(instruction_ref ins) static int get_axis(instruction_ref ins)
{ {
auto op = any_cast<op::concat>(ins->get_operator()); auto op = any_cast<op::concat>(ins->get_operator());
return op.axis; return op.axis;
...@@ -365,7 +365,7 @@ struct find_resize ...@@ -365,7 +365,7 @@ struct find_resize
} }
// output must be multiple of inputs // output must be multiple of inputs
std::vector<std::size_t> scales(in_lens.size()); std::vector<int> scales(in_lens.size());
std::transform( std::transform(
in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) { in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) {
return y / x; return y / x;
...@@ -394,7 +394,7 @@ struct find_resize ...@@ -394,7 +394,7 @@ struct find_resize
} }
// wrap up shapes for multibroadcast // wrap up shapes for multibroadcast
std::vector<std::pair<std::size_t, std::size_t>> dim_scales; std::vector<std::pair<int, int>> dim_scales;
std::transform(in_lens.begin(), std::transform(in_lens.begin(),
in_lens.end(), in_lens.end(),
out_lens.begin(), out_lens.begin(),
......
...@@ -18,7 +18,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat> ...@@ -18,7 +18,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
struct desc struct desc
{ {
dnnl::memory::desc dst; dnnl::memory::desc dst;
std::size_t axis = 1; int axis = 1;
std::vector<dnnl::memory::desc> srcs; std::vector<dnnl::memory::desc> srcs;
}; };
desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
...@@ -30,7 +30,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat> ...@@ -30,7 +30,7 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
{ {
srcs.push_back(m.at(MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC) + i)); srcs.push_back(m.at(MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC) + i));
} }
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), std::size_t(op.axis), srcs}; return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), int(op.axis), srcs};
} }
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
......
...@@ -44,8 +44,8 @@ struct dnnl_convolution ...@@ -44,8 +44,8 @@ struct dnnl_convolution
std::transform( std::transform(
dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; }); dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; });
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<int> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<int> padding_r(op.padding.begin() + kdims, op.padding.end());
return {dnnl::prop_kind::forward_inference, return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::convolution_auto, dnnl::algorithm::convolution_auto,
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
......
...@@ -9,7 +9,7 @@ template <> ...@@ -9,7 +9,7 @@ template <>
struct hash<dnnl::algorithm> struct hash<dnnl::algorithm>
{ {
using argument_type = dnnl::algorithm; using argument_type = dnnl::algorithm;
using result_type = std::size_t; using result_type = int;
result_type operator()(const argument_type& x) const noexcept result_type operator()(const argument_type& x) const noexcept
{ {
return std::hash<underlying_type_t<argument_type>>{}( return std::hash<underlying_type_t<argument_type>>{}(
...@@ -52,7 +52,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) ...@@ -52,7 +52,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
#pragma clang diagnostic pop #pragma clang diagnostic pop
#endif #endif
dnnl::memory::format_tag to_dnnl_memory_format_tag(std::size_t n) dnnl::memory::format_tag to_dnnl_memory_format_tag(int n)
{ {
switch(n) switch(n)
{ {
......
...@@ -98,7 +98,7 @@ struct find_post_ops ...@@ -98,7 +98,7 @@ struct find_post_ops
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
for(std::size_t i = 0; i < 4; i++) for(int i = 0; i < 4; i++)
{ {
match::find_matches(m, find_post_ops{ctx}); match::find_matches(m, find_post_ops{ctx});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -29,7 +29,7 @@ struct cpu_gather : auto_register_op<cpu_gather> ...@@ -29,7 +29,7 @@ struct cpu_gather : auto_register_op<cpu_gather>
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
std::size_t nelements = output_shape.elements(); int nelements = output_shape.elements();
auto lens = args[0].get_shape().lens(); auto lens = args[0].get_shape().lens();
auto axis_dim_size = lens[op.axis]; auto axis_dim_size = lens[op.axis];
lens[op.axis] = args[1].get_shape().elements(); lens[op.axis] = args[1].get_shape().elements();
......
...@@ -15,13 +15,13 @@ struct context ...@@ -15,13 +15,13 @@ struct context
void finish() const {} void finish() const {}
template <class F> template <class F>
void bulk_execute(std::size_t n, std::size_t min_grain, F f) void bulk_execute(int n, int min_grain, F f)
{ {
cpu::parallel_for(n, min_grain, f); cpu::parallel_for(n, min_grain, f);
} }
template <class F> template <class F>
void bulk_execute(std::size_t n, F f) void bulk_execute(int n, F f)
{ {
this->bulk_execute(n, 256, f); this->bulk_execute(n, 256, f);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment