Commit e2eb6036 authored by Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
......@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <utility>
......@@ -26,6 +27,8 @@ struct reshape
return pack(f(self.dims, "dims"));
}
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -72,7 +75,6 @@ struct reshape
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -3,6 +3,7 @@
#include <limits>
#include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
......@@ -21,7 +22,7 @@ namespace op {
struct roialign
{
std::string coord_trans_mode = "half_pixel";
std::string mode = "avg";
pooling_mode mode = {pooling_mode::average};
int64_t output_height = 1;
int64_t output_width = 1;
int64_t sampling_ratio = 0;
......@@ -42,7 +43,7 @@ struct roialign
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
auto x_lens = inputs.at(0).lens();
auto roi_lens = inputs.at(1).lens();
auto bi_lens = inputs.at(2).lens();
......@@ -241,19 +242,19 @@ struct roialign
in_dims[0] * in_dims[1]);
double output_val;
std::tie(output_val, vec_index[c]) =
(mode == "avg") ? this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
avg_pool{})
: this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
max_pool{});
(mode == migraphx::op::pooling_mode::average)
? this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
avg_pool{})
: this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
max_pool{});
output(n, c, ph, pw) = output_val;
});
});
});
......
......@@ -40,7 +40,6 @@ struct scalar
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// ScatterND variant with additive reduction: each scattered update is
/// accumulated into the existing output element (x += y).
struct scatternd_add : scatternd_op<scatternd_add>
{
    // = default instead of {} keeps the type trivially default-constructible
    // and lets the compiler generate the constructor (modernize-use-equals-default).
    scatternd_add() = default;

    /// Binary reduction applied by scatternd_op::compute at each target index.
    auto reduction() const
    {
        return [](auto& x, const auto& y) { x += y; };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// ScatterND variant with multiplicative reduction: each scattered update is
/// multiplied into the existing output element (x *= y).
struct scatternd_mul : scatternd_op<scatternd_mul>
{
    // = default instead of {} keeps the type trivially default-constructible
    // and lets the compiler generate the constructor (modernize-use-equals-default).
    scatternd_mul() = default;

    /// Binary reduction applied by scatternd_op::compute at each target index.
    auto reduction() const
    {
        return [](auto& x, const auto& y) { x *= y; };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// ScatterND variant with no reduction: each scattered update overwrites the
/// existing output element (x = y).
struct scatternd_none : scatternd_op<scatternd_none>
{
    // = default instead of {} keeps the type trivially default-constructible
    // and lets the compiler generate the constructor (modernize-use-equals-default).
    scatternd_none() = default;

    /// Binary reduction applied by scatternd_op::compute at each target index.
    auto reduction() const
    {
        return [](auto& x, const auto& y) { x = y; };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// CRTP base for the scatternd_* operators. The Derived class supplies
// reduction(), a binary functor combining an existing output element with an
// update value; shape validation and the scatter loop live here. op_name
// derives the operator's string name from the Derived type.
template <class Derived>
struct scatternd_op : op_name<Derived>
{
// Validates the three inputs (data, indices, updates) and returns the output
// shape, which matches the data input.
// Notation (following ONNX ScatterND): r = rank(data), q = rank(indices),
// k = indices.lens.back() = number of leading data dimensions each index
// tuple addresses. Requires k <= r and
// updates.lens == indices.lens[0:q-1] ++ data.lens[k:r].
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
// For a broadcasted input, return a shape with default (packed) strides;
// otherwise preserve the input's layout via with_lens.
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
// Copies data into the output, then for every element of updates computes its
// target location from indices and applies Derived::reduction() there.
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto& self = static_cast<const Derived&>(*this);
// args[0] = data, args[1] = indices, args[2] = updates.
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
// Start from a copy of the data tensor; scattered elements are then
// overwritten/combined in place.
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto updates_shape = updates.get_shape();
// Standard-strided copy of the updates shape, used to map the
// linear loop index i to a multi-index.
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
// NOTE(review): par_for runs iterations concurrently; updates that
// target the same output element are not synchronized here —
// presumably duplicate indices are disallowed/undefined, as in the
// ONNX spec. Confirm for the reduction variants.
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
// The first q-1 coordinates of the update element select which
// index tuple (of length k) to read from indices.
std::vector<std::size_t> indices_idx(q, 0);
std::copy(
updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices.begin() +
indices_shape.index(indices_idx.begin(), indices_idx.end());
auto index_end = index_start + k;
// Output multi-index = the k values from indices, followed by
// the trailing (q-1..end) coordinates of the update element.
std::vector<std::size_t> out_idx(r, 0);
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
});
});
});
return result;
}
// No-op initialization hook (returns void despite auto).
auto init() const {}
scatternd_op() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -37,48 +37,53 @@ struct squeeze
std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
std::copy_if(old_lens.begin(),
old_lens.end(),
std::back_inserter(new_lens),
[](auto len) { return len != 1; });
for(auto i : range(old_lens.size()))
{
if(old_lens[i] != 1)
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
else
{
for(std::size_t i = 0; i < old_lens.size(); i++)
for(auto i : range(old_lens.size()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens};
return shape{type, new_lens, new_strides};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -72,8 +72,6 @@ struct step
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -56,7 +56,6 @@ struct transpose
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -64,7 +64,6 @@ struct unary : op_name<Derived>
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
});
return result;
......
......@@ -37,11 +37,11 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard_or_scalar();
check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
......@@ -53,25 +53,34 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++)
for(auto i : range(new_size))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
}
else
{
new_lens[i] = old_lens[p++];
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
}
return shape{type, new_lens};
return shape{type, new_lens, new_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&)
} // namespace detail
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const;
* lifetime get_lifetime() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
* shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input)
* const; argument compute(const shape& output,const std::vector<argument>& input,const
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
* module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
* from_value(const value& v) ; value attributes() const; friend std::ostream &
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
* x,const operation & y) ;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct operation
{
//
std::string name() const;
// (optional)
bool is_context_free() const;
// (optional)
bool need_normalization() const;
// (optional)
bool has_finalize() const;
// (optional)
lifetime get_lifetime() const;
// (optional)
std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
// (optional)
value compile(context& ctx, const shape& output, const std::vector<shape>& input);
// (optional)
void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
// (optional)
shape compute_shape(const std::vector<shape>& input) const;
// (optional)
shape compute_shape(const std::vector<shape>& inputs,
const std::vector<module_ref>& mod_args) const;
// (optional)
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
value attributes() const;
//
friend std::ostream& operator<<(std::ostream& os, const operation& op);
//
friend bool operator==(const operation& x, const operation& y);
};
#else
struct operation
{
......@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x)
throw std::bad_cast();
return *y;
}
#endif
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
......
......@@ -41,6 +41,7 @@
#include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp>
......@@ -86,6 +87,9 @@
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......
......@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
{
dfor(xs...)(f);
}
};
}
......
......@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
}
......
......@@ -58,17 +58,20 @@ void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
} // namespace detail
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(module_pass_manager & mpm) const;
* void apply(program & p) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct pass
{
//
std::string name() const;
// (optional)
void apply(module_pass_manager& mpm) const;
// (optional)
void apply(program& p) const;
};
#else
struct pass
{
......@@ -303,6 +306,7 @@ inline const ValueType& any_cast(const pass& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -81,6 +81,9 @@ struct program
const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print(const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const;
......
......@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(module& p, instruction_ref ins, std::size_t n) const;
void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction
void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(module& p, instruction_ref ins, std::size_t wait_id) const;
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(module& p,instruction_ref ins,std::size_t n) const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
* void record(module& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct schedule_model
{
//
std::size_t concurrency() const;
//
void sched(module& m, instruction_ref ins, std::size_t n) const;
//
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
//
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
//
std::size_t weight(const operation& op) const;
};
#else
struct schedule_model
{
......@@ -120,22 +125,22 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency();
}
void sched(module& p, instruction_ref ins, std::size_t n) const
void sched(module& m, instruction_ref ins, std::size_t n) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n);
(*this).private_detail_te_get_handle().sched(m, ins, n);
}
void wait(module& p, instruction_ref ins, std::size_t wait_id) const
void wait(module& m, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id);
(*this).private_detail_te_get_handle().wait(m, ins, wait_id);
}
void record(module& p, instruction_ref ins, std::size_t wait_id) const
void record(module& m, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id);
(*this).private_detail_te_get_handle().record(m, ins, wait_id);
}
std::size_t weight(const operation& op) const
......@@ -159,9 +164,9 @@ struct schedule_model
virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
};
......@@ -195,22 +200,22 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(module& p, instruction_ref ins, std::size_t n) const override
void sched(module& m, instruction_ref ins, std::size_t n) const override
{
private_detail_te_value.sched(p, ins, n);
private_detail_te_value.sched(m, ins, n);
}
void wait(module& p, instruction_ref ins, std::size_t wait_id) const override
void wait(module& m, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.wait(p, ins, wait_id);
private_detail_te_value.wait(m, ins, wait_id);
}
void record(module& p, instruction_ref ins, std::size_t wait_id) const override
void record(module& m, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.record(p, ins, wait_id);
private_detail_te_value.record(m, ins, wait_id);
}
std::size_t weight(const operation& op) const override
......@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
......@@ -131,6 +131,8 @@ struct shape
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......@@ -190,8 +192,7 @@ struct shape
{
switch(t)
{
case tuple_type:
{
case tuple_type: {
tv();
return;
}
......@@ -228,10 +229,11 @@ struct shape
const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
};
void migraphx_to_value(value& v, const shape& s);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment