Commit edc23800 authored by Shucai Xiao
Browse files

Change the data type for lens and strides from size_t to int in the shape class.

parent c7419a9c
...@@ -20,15 +20,15 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,15 +20,15 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis // In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving // of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5) // output_lens = (3,2,7,5)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<int> compute_broadcasted_lens(std::vector<int> s0,
std::vector<std::size_t> s1) std::vector<int> s1)
{ {
if(s0 == s1) if(s0 == s1)
return s0; return s0;
if(s0.size() > s1.size()) if(s0.size() > s1.size())
s0.swap(s1); s0.swap(s1);
std::vector<std::size_t> out_lens(s1); std::vector<int> out_lens(s1);
auto offset = s1.size() - s0.size(); auto offset = s1.size() - s0.size();
std::transform( std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) { s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
...@@ -43,7 +43,7 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, ...@@ -43,7 +43,7 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return out_lens; return out_lens;
} }
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<int> compute_common_lens(const std::vector<shape>& shapes)
{ {
assert(not shapes.empty()); assert(not shapes.empty());
return transform_accumulate(shapes.begin() + 1, return transform_accumulate(shapes.begin() + 1,
......
...@@ -17,8 +17,8 @@ void eliminate_allocation::apply(module& p) const ...@@ -17,8 +17,8 @@ void eliminate_allocation::apply(module& p) const
{ {
assert(alignment > 0); assert(alignment > 0);
std::size_t n = 0; int n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, int>> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(ins->name() != allocation_op) if(ins->name() != allocation_op)
......
...@@ -36,7 +36,7 @@ void eliminate_concat::apply(module& p) const ...@@ -36,7 +36,7 @@ void eliminate_concat::apply(module& p) const
// we only need to check the first input // we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens(); auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator()); auto concat_op = concat_opt.get_concat(ins->get_operator());
std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name()); int axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
if(axis_index == 0 || if(axis_index == 0 ||
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; })) std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{ {
...@@ -70,7 +70,7 @@ void eliminate_concat::apply(module& p) const ...@@ -70,7 +70,7 @@ void eliminate_concat::apply(module& p) const
auto first = sorted_allocations.front(); auto first = sorted_allocations.front();
auto super = p.move_instruction(last, first); auto super = p.move_instruction(last, first);
// Replace each allocation with a load // Replace each allocation with a load
std::size_t offset = 0; int offset = 0;
for(auto alloc : allocations) for(auto alloc : allocations)
{ {
op::load op{alloc->get_shape(), offset}; op::load op{alloc->get_shape(), offset};
......
...@@ -39,7 +39,7 @@ struct check_shapes ...@@ -39,7 +39,7 @@ struct check_shapes
return name + ": "; return name + ": ";
} }
std::size_t size() const int size() const
{ {
if(begin == end) if(begin == end)
return 0; return 0;
...@@ -57,14 +57,14 @@ struct check_shapes ...@@ -57,14 +57,14 @@ struct check_shapes
return *this; return *this;
} }
const check_shapes& nelements(std::size_t n) const const check_shapes& nelements(int n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements"); MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
return *this; return *this;
} }
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(int n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
...@@ -76,7 +76,7 @@ struct check_shapes ...@@ -76,7 +76,7 @@ struct check_shapes
return *this; return *this;
} }
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(int n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
...@@ -89,7 +89,7 @@ struct check_shapes ...@@ -89,7 +89,7 @@ struct check_shapes
return *this; return *this;
} }
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(int n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
...@@ -179,7 +179,7 @@ struct check_shapes ...@@ -179,7 +179,7 @@ struct check_shapes
return *this; return *this;
} }
const check_shapes& elements(std::size_t n) const const check_shapes& elements(int n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements"); MIGRAPHX_THROW(prefix() + "Wrong number of elements");
...@@ -230,13 +230,13 @@ struct check_shapes ...@@ -230,13 +230,13 @@ struct check_shapes
check_shapes slice(long start, long last) const { return {get(start), get(last), name}; } check_shapes slice(long start, long last) const { return {get(start), get(last), name}; }
private: private:
static bool batch_not_transposed_strides(const std::vector<std::size_t>& strides) static bool batch_not_transposed_strides(const std::vector<int>& strides)
{ {
if(strides.size() <= 2) if(strides.size() <= 2)
return true; return true;
auto dim_0 = strides.size() - 2; auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]); auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0); std::vector<int> batch(strides.begin(), strides.begin() + dim_0);
if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); })) if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
{ {
return false; return false;
......
...@@ -11,8 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,8 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct operation; struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<int> compute_broadcasted_lens(std::vector<int> s0,
std::vector<std::size_t> s1); std::vector<int> s1);
shape common_shape(const std::vector<shape>& shapes); shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m, instruction_ref insert_common_op(module& m,
......
...@@ -19,7 +19,7 @@ struct select_dependent_type ...@@ -19,7 +19,7 @@ struct select_dependent_type
template <class T, class... Ts> template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type; using dependent_type = typename select_dependent_type<T, Ts...>::type;
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens); bool normalize_attributes(operation& op, const std::vector<int>& lens);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -11,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct onnx_options struct onnx_options
{ {
/// default batch size to use (if not specified in onnx file) /// default batch size to use (if not specified in onnx file)
std::size_t default_dim_value = 1; int default_dim_value = 1;
/// Explicitly specify the dims of an input /// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<int>> map_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found /// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
/// Print program if an error occurs /// Print program if an error occurs
...@@ -29,7 +29,7 @@ program parse_onnx(const std::string& name, const onnx_options& = onnx_options{} ...@@ -29,7 +29,7 @@ program parse_onnx(const std::string& name, const onnx_options& = onnx_options{}
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options); program parse_onnx_buffer(const std::string& buffer, const onnx_options& options);
/// Create a program from an onnx buffer /// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options); program parse_onnx_buffer(const void* data, int size, const onnx_options& options);
std::vector<std::string> get_onnx_operators(); std::vector<std::string> get_onnx_operators();
......
...@@ -44,11 +44,11 @@ struct argmax ...@@ -44,11 +44,11 @@ struct argmax
} }
template <class T> template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmax(T& input, std::vector<int>& indices, int item_num) const
{ {
auto max_val = input(indices.begin(), indices.end()); auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0; int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i) for(int i = 1; i < item_num; ++i)
{ {
indices[axis] = i; indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end()); auto cur_val = input(indices.begin(), indices.end());
......
...@@ -44,11 +44,11 @@ struct argmin ...@@ -44,11 +44,11 @@ struct argmin
} }
template <class T> template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmin(T& input, std::vector<int>& indices, int item_num) const
{ {
auto min_val = input(indices.begin(), indices.end()); auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0; int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i) for(int i = 1; i < item_num; ++i)
{ {
indices[axis] = i; indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end()); auto cur_val = input(indices.begin(), indices.end());
...@@ -65,7 +65,7 @@ struct argmin ...@@ -65,7 +65,7 @@ struct argmin
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis]; int batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
......
...@@ -25,7 +25,7 @@ namespace op { ...@@ -25,7 +25,7 @@ namespace op {
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens; std::vector<int> broadcast_lens;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -39,7 +39,7 @@ struct broadcast ...@@ -39,7 +39,7 @@ struct broadcast
auto input = inputs.at(0); auto input = inputs.at(0);
auto t = input.type(); auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<int> bcast_strides(broadcast_lens.size(), 0);
// the broacast op is deprecated now, so not handling the negative // the broacast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
......
...@@ -37,12 +37,12 @@ struct concat ...@@ -37,12 +37,12 @@ struct concat
} }
std::string name() const { return "concat"; } std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape, std::vector<int> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto n_dims = args[0].get_shape().lens().size(); auto n_dims = args[0].get_shape().lens().size();
std::vector<std::size_t> offsets; std::vector<int> offsets;
std::vector<std::size_t> offset(n_dims, 0); std::vector<int> offset(n_dims, 0);
offset[axis] = 0; offset[axis] = 0;
for(const auto& arg : args) for(const auto& arg : args)
{ {
...@@ -60,7 +60,7 @@ struct concat ...@@ -60,7 +60,7 @@ struct concat
const auto& first_shape_lens = inputs.front().lens(); const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type(); const auto& type = inputs.front().type();
for(std::size_t l = 0; l < first_shape_lens.size(); l++) for(int l = 0; l < first_shape_lens.size(); l++)
{ {
if(l != axis) if(l != axis)
{ {
...@@ -72,13 +72,13 @@ struct concat ...@@ -72,13 +72,13 @@ struct concat
} }
} }
} }
std::size_t new_dim_axis = 0; int new_dim_axis = 0;
for(const auto& input : inputs) for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); const auto& lens = input.lens();
new_dim_axis += lens[axis]; new_dim_axis += lens[axis];
} }
std::vector<std::size_t> new_lens; std::vector<int> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens)); std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis; new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs)); return shape::from_permutation(type, new_lens, find_permutation(inputs));
...@@ -86,8 +86,8 @@ struct concat ...@@ -86,8 +86,8 @@ struct concat
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<int> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(int l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
......
...@@ -20,9 +20,9 @@ namespace op { ...@@ -20,9 +20,9 @@ namespace op {
struct convolution struct convolution
{ {
std::vector<std::size_t> padding = {0, 0}; std::vector<int> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1}; std::vector<int> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1}; std::vector<int> dilation = {1, 1};
int group = 1; int group = 1;
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
...@@ -64,7 +64,7 @@ struct convolution ...@@ -64,7 +64,7 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
size_t kdims = input_size - 2; int kdims = input_size - 2;
if(kdims != this->kdims()) if(kdims != this->kdims())
{ {
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size"); MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
...@@ -73,14 +73,14 @@ struct convolution ...@@ -73,14 +73,14 @@ struct convolution
if(input.lens().at(1) != (weights.lens().at(1) * group)) if(input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers"); MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]}; std::vector<int> output_lens{input.lens()[0], weights.lens()[0]};
for(size_t i = 0; i < kdims; i++) for(int i = 0; i < kdims; i++)
{ {
auto padding_factor = 2 * padding[i]; auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims) if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims]; padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>( output_lens.push_back(int(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) / padding_factor) /
...@@ -91,7 +91,7 @@ struct convolution ...@@ -91,7 +91,7 @@ struct convolution
return inputs[0].with_lens(output_lens); return inputs[0].with_lens(output_lens);
} }
size_t kdims() const int kdims() const
{ {
check_attribute_size(); check_attribute_size();
return stride.size(); return stride.size();
......
...@@ -20,9 +20,9 @@ namespace op { ...@@ -20,9 +20,9 @@ namespace op {
struct deconvolution struct deconvolution
{ {
std::vector<std::size_t> padding = {0, 0}; std::vector<int> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1}; std::vector<int> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1}; std::vector<int> dilation = {1, 1};
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
int group = 1; int group = 1;
...@@ -54,17 +54,17 @@ struct deconvolution ...@@ -54,17 +54,17 @@ struct deconvolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
size_t kdims = input.lens().size() - 2; int kdims = input.lens().size() - 2;
if(kdims != this->kdims()) if(kdims != this->kdims())
{ {
MIGRAPHX_THROW("deconvolution: input k-dims does not match attribute size"); MIGRAPHX_THROW("deconvolution: input k-dims does not match attribute size");
} }
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[1]}; std::vector<int> output_lens{input.lens()[0], weights.lens()[1]};
for(size_t i = 0; i < kdims; i++) for(int i = 0; i < kdims; i++)
{ {
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>( output_lens.push_back(int(std::max<std::ptrdiff_t>(
1, 1,
stride[i] * (input.lens()[i + 2] - 1) + stride[i] * (input.lens()[i + 2] - 1) +
((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i]))); ((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i])));
...@@ -91,7 +91,7 @@ struct deconvolution ...@@ -91,7 +91,7 @@ struct deconvolution
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
std::vector<std::size_t> win_size{in_c}; std::vector<int> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size)); std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size)); std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size}; shape win_shape{output_shape.type(), win_size};
...@@ -105,7 +105,7 @@ struct deconvolution ...@@ -105,7 +105,7 @@ struct deconvolution
auto wei_dims_start = idx_win.begin() + kdims + 1; auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start; std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n) for(int n = 0; n < kdims; ++n)
{ {
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) - win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
std::ptrdiff_t(padding[n])); std::ptrdiff_t(padding[n]));
...@@ -116,7 +116,7 @@ struct deconvolution ...@@ -116,7 +116,7 @@ struct deconvolution
std::vector<std::ptrdiff_t> idx_out{o, in_ch}; std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++) for(int n = 0; n < kdims; n++)
{ {
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]); idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
} }
...@@ -147,7 +147,7 @@ struct deconvolution ...@@ -147,7 +147,7 @@ struct deconvolution
return result; return result;
} }
size_t kdims() const int kdims() const
{ {
check_attribute_size(); check_attribute_size();
return stride.size(); return stride.size();
......
...@@ -42,9 +42,9 @@ struct flatten ...@@ -42,9 +42,9 @@ struct flatten
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
auto x = auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin(), lens.begin() + axis, int{1}, std::multiplies<>{});
auto y = auto y =
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin() + axis, lens.end(), int{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}}; return {inputs.at(0).type(), {x, y}};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -21,7 +21,7 @@ namespace op { ...@@ -21,7 +21,7 @@ namespace op {
struct gru struct gru
{ {
std::size_t hidden_size = 1; int hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}}; std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
...@@ -47,7 +47,7 @@ struct gru ...@@ -47,7 +47,7 @@ struct gru
MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input"); MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
} }
std::size_t num_directions = 1; int num_directions = 1;
if(direction == rnn_direction::bidirectional) if(direction == rnn_direction::bidirectional)
{ {
num_directions = 2; num_directions = 2;
...@@ -58,7 +58,7 @@ struct gru ...@@ -58,7 +58,7 @@ struct gru
MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute"); MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
} }
std::vector<std::size_t> out_dims(in_dims); std::vector<int> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions); out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size; out_dims.back() = hidden_size;
......
...@@ -14,9 +14,9 @@ namespace op { ...@@ -14,9 +14,9 @@ namespace op {
struct im2col struct im2col
{ {
std::vector<std::size_t> padding{0, 0}; std::vector<int> padding{0, 0};
std::vector<std::size_t> stride{1, 1}; std::vector<int> stride{1, 1};
std::vector<std::size_t> dilation{1, 1}; std::vector<int> dilation{1, 1};
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
...@@ -52,11 +52,11 @@ struct im2col ...@@ -52,11 +52,11 @@ struct im2col
padding_h = padding[0] + padding[2]; padding_h = padding[0] + padding[2];
padding_w = padding[1] + padding[3]; padding_w = padding[1] + padding[3];
} }
auto output_height = std::size_t(std::max<std::ptrdiff_t>( auto output_height = int(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] + (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] +
1)); 1));
auto output_width = std::size_t(std::max<std::ptrdiff_t>( auto output_width = int(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] + (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] +
1)); 1));
......
...@@ -17,7 +17,7 @@ namespace op { ...@@ -17,7 +17,7 @@ namespace op {
struct load struct load
{ {
shape s; shape s;
std::size_t offset = 0; int offset = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
...@@ -21,7 +21,7 @@ namespace op { ...@@ -21,7 +21,7 @@ namespace op {
struct lstm struct lstm
{ {
std::size_t hidden_size = 1; int hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}}; std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
...@@ -47,7 +47,7 @@ struct lstm ...@@ -47,7 +47,7 @@ struct lstm
MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input"); MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
} }
std::size_t num_directions = 1; int num_directions = 1;
if(direction == rnn_direction::bidirectional) if(direction == rnn_direction::bidirectional)
{ {
num_directions = 2; num_directions = 2;
...@@ -58,7 +58,7 @@ struct lstm ...@@ -58,7 +58,7 @@ struct lstm
MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute"); MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
} }
std::vector<std::size_t> out_dims(in_dims); std::vector<int> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions); out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size; out_dims.back() = hidden_size;
......
...@@ -24,7 +24,7 @@ struct multinomial ...@@ -24,7 +24,7 @@ struct multinomial
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).only_dims(2); check_shapes{inputs, *this}.has(2).only_dims(2);
size_t sample_size = inputs.back().lens().back(); int sample_size = inputs.back().lens().back();
if(not contains({shape::int32_type, shape::int64_type}, dtype)) if(not contains({shape::int32_type, shape::int64_type}, dtype))
MIGRAPHX_THROW( MIGRAPHX_THROW(
...@@ -36,9 +36,9 @@ struct multinomial ...@@ -36,9 +36,9 @@ struct multinomial
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
size_t batch_size = output_shape.lens().front(); int batch_size = output_shape.lens().front();
size_t class_size = args[0].get_shape().lens().back(); int class_size = args[0].get_shape().lens().back();
size_t sample_size = output_shape.lens().back(); int sample_size = output_shape.lens().back();
visit_all(args[0], args[1])([&](auto cdf, auto dist) { visit_all(args[0], args[1])([&](auto cdf, auto dist) {
result.visit([&](auto output) { result.visit([&](auto output) {
......
...@@ -21,15 +21,15 @@ struct nonzero ...@@ -21,15 +21,15 @@ struct nonzero
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto elem_num = inputs[0].elements(); auto elem_num = inputs[0].elements();
auto dim_num = inputs[0].lens().size(); int dim_num = inputs[0].lens().size();
std::vector<std::size_t> out_lens = {dim_num, elem_num}; std::vector<int> out_lens = {dim_num, elem_num};
return {shape::int64_type, out_lens}; return {shape::int64_type, out_lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
std::vector<std::vector<std::size_t>> vec_idx; std::vector<std::vector<int>> vec_idx;
auto s = args.front().get_shape(); auto s = args.front().get_shape();
args.front().visit([&](auto v) { args.front().visit([&](auto v) {
shape_for_each(s, [&](auto idx) { shape_for_each(s, [&](auto idx) {
...@@ -44,7 +44,7 @@ struct nonzero ...@@ -44,7 +44,7 @@ struct nonzero
result.visit([&](auto output) { result.visit([&](auto output) {
std::fill(output.begin(), output.end(), 0); std::fill(output.begin(), output.end(), 0);
par_for(vec_idx.size(), [&](auto i) { par_for(vec_idx.size(), [&](auto i) {
for(std::size_t j = 0; j < vec_idx.front().size(); ++j) for(int j = 0; j < vec_idx.front().size(); ++j)
{ {
output[output_shape.index({j, i})] = vec_idx[i][j]; output[output_shape.index({j, i})] = vec_idx[i][j];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment