Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp> #include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -26,6 +27,8 @@ struct reshape ...@@ -26,6 +27,8 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -72,7 +75,6 @@ struct reshape ...@@ -72,7 +75,6 @@ struct reshape
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <limits> #include <limits>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
...@@ -21,7 +22,7 @@ namespace op { ...@@ -21,7 +22,7 @@ namespace op {
struct roialign struct roialign
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode = "half_pixel";
std::string mode = "avg"; pooling_mode mode = {pooling_mode::average};
int64_t output_height = 1; int64_t output_height = 1;
int64_t output_width = 1; int64_t output_width = 1;
int64_t sampling_ratio = 0; int64_t sampling_ratio = 0;
...@@ -42,7 +43,7 @@ struct roialign ...@@ -42,7 +43,7 @@ struct roialign
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).standard(); check_shapes{inputs, *this}.has(3);
auto x_lens = inputs.at(0).lens(); auto x_lens = inputs.at(0).lens();
auto roi_lens = inputs.at(1).lens(); auto roi_lens = inputs.at(1).lens();
auto bi_lens = inputs.at(2).lens(); auto bi_lens = inputs.at(2).lens();
...@@ -241,19 +242,19 @@ struct roialign ...@@ -241,19 +242,19 @@ struct roialign
in_dims[0] * in_dims[1]); in_dims[0] * in_dims[1]);
double output_val; double output_val;
std::tie(output_val, vec_index[c]) = std::tie(output_val, vec_index[c]) =
(mode == "avg") ? this->calc_pooling(offset_bottom_data, (mode == migraphx::op::pooling_mode::average)
bin_grid_size, ? this->calc_pooling(offset_bottom_data,
pre_calc, bin_grid_size,
vec_index[c], pre_calc,
avg_pool{}) vec_index[c],
: this->calc_pooling(offset_bottom_data, avg_pool{})
bin_grid_size, : this->calc_pooling(offset_bottom_data,
pre_calc, bin_grid_size,
vec_index[c], pre_calc,
max_pool{}); vec_index[c],
max_pool{});
output(n, c, ph, pw) = output_val; output(n, c, ph, pw) = output_val;
}); });
}); });
}); });
......
...@@ -40,7 +40,6 @@ struct scalar ...@@ -40,7 +40,6 @@ struct scalar
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_add : scatternd_op<scatternd_add>
{
scatternd_add() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_mul : scatternd_op<scatternd_mul>
{
scatternd_mul() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_none : scatternd_op<scatternd_none>
{
scatternd_none() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct scatternd_op : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto updates_shape = updates.get_shape();
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
std::copy(
updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices.begin() +
indices_shape.index(indices_idx.begin(), indices_idx.end());
auto index_end = index_start + k;
std::vector<std::size_t> out_idx(r, 0);
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
});
});
});
return result;
}
auto init() const {}
scatternd_op() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -37,48 +37,53 @@ struct squeeze ...@@ -37,48 +37,53 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty()) if(axes.empty())
{ {
std::copy_if(old_lens.begin(), for(auto i : range(old_lens.size()))
old_lens.end(), {
std::back_inserter(new_lens), if(old_lens[i] != 1)
[](auto len) { return len != 1; }); {
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
} }
else else
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(auto i : range(old_lens.size()))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
} }
} }
} }
if(new_lens.empty()) if(new_lens.empty())
{ {
return shape{type}; return shape{type};
} }
else else
{ {
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -72,8 +72,6 @@ struct step ...@@ -72,8 +72,6 @@ struct step
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -56,7 +56,6 @@ struct transpose ...@@ -56,7 +56,6 @@ struct transpose
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -64,7 +64,6 @@ struct unary : op_name<Derived> ...@@ -64,7 +64,6 @@ struct unary : op_name<Derived>
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
}); });
}); });
return result; return result;
......
...@@ -37,11 +37,11 @@ struct unsqueeze ...@@ -37,11 +37,11 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar()) if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(old_lens.size() == 1 and old_lens.front() == 1)
...@@ -53,25 +53,34 @@ struct unsqueeze ...@@ -53,25 +53,34 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
} }
else else
{ {
new_lens[i] = old_lens[p++]; new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
} }
} }
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&) ...@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&)
} // namespace detail } // namespace detail
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct operation struct operation
* { {
* std::string name() const; //
* bool is_context_free() const; std::string name() const;
* bool need_normalization() const; // (optional)
* bool has_finalize() const; bool is_context_free() const;
* lifetime get_lifetime() const; // (optional)
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; bool need_normalization() const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ; // (optional)
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; bool has_finalize() const;
* shape compute_shape(const std::vector<shape>& input) const; // (optional)
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& lifetime get_lifetime() const;
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>& // (optional)
* input) const; argument compute(const shape& output,const std::vector<argument>& input) std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* const; argument compute(const shape& output,const std::vector<argument>& input,const // (optional)
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const value compile(context& ctx, const shape& output, const std::vector<shape>& input);
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const // (optional)
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>& void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
* module_args,std::function<std::vector<argument>(module_ref&, const // (optional)
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void shape compute_shape(const std::vector<shape>& input) const;
* from_value(const value& v) ; value attributes() const; friend std::ostream & // (optional)
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation & shape compute_shape(const std::vector<shape>& inputs,
* x,const operation & y) ; const std::vector<module_ref>& mod_args) const;
* }; // (optional)
* argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
*/ // (optional)
argument compute(const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
value attributes() const;
//
friend std::ostream& operator<<(std::ostream& os, const operation& op);
//
friend bool operator==(const operation& x, const operation& y);
};
#else
struct operation struct operation
{ {
...@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp> #include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/leaky_relu.hpp> #include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp> #include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp> #include <migraphx/op/load.hpp>
...@@ -86,6 +87,9 @@ ...@@ -86,6 +87,9 @@
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
......
...@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs) ...@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
{ {
dfor(xs...)(f); dfor(xs...)(f);
} }
}; };
} }
......
...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, std::size_t min_grain, F f)
{ {
const auto threadsize = const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain); n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f); par_for_impl(n, threadsize, f);
} }
......
...@@ -58,17 +58,20 @@ void module_pass_manager_apply(const T& x, module_pass_manager& mpm) ...@@ -58,17 +58,20 @@ void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
} // namespace detail } // namespace detail
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct pass struct pass
* { {
* std::string name() const; //
* void apply(module_pass_manager & mpm) const; std::string name() const;
* void apply(program & p) const; // (optional)
* }; void apply(module_pass_manager& mpm) const;
* // (optional)
*/ void apply(program& p) const;
};
#else
struct pass struct pass
{ {
...@@ -303,6 +306,7 @@ inline const ValueType& any_cast(const pass& x) ...@@ -303,6 +306,7 @@ inline const ValueType& any_cast(const pass& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -81,6 +81,9 @@ struct program ...@@ -81,6 +81,9 @@ struct program
const std::function<void(instruction_ref, const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>& std::unordered_map<instruction_ref, std::string>)>&
print_func) const; print_func) const;
void print(const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
......
...@@ -26,30 +26,35 @@ struct schedule_model ...@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed /// Get the number of concurrent instruction allowed
std::size_t concurrency() const; std::size_t concurrency() const;
/// Schedule a concurrent instruction /// Schedule a concurrent instruction
void sched(module& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction // Insert necessary waits before an instruction
void wait(module& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction // Insert necessary records after an instruction
void record(module& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation /// Compute weights for an operation
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct schedule_model struct schedule_model
* { {
* std::size_t concurrency() const; //
* void sched(module& p,instruction_ref ins,std::size_t n) const; std::size_t concurrency() const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const; //
* void record(module& p,instruction_ref ins,std::size_t wait_id) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
* std::size_t weight(const operation& op) const; //
* }; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
* //
*/ void record(module& m, instruction_ref ins, std::size_t wait_id) const;
//
std::size_t weight(const operation& op) const;
};
#else
struct schedule_model struct schedule_model
{ {
...@@ -120,22 +125,22 @@ struct schedule_model ...@@ -120,22 +125,22 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency(); return (*this).private_detail_te_get_handle().concurrency();
} }
void sched(module& p, instruction_ref ins, std::size_t n) const void sched(module& m, instruction_ref ins, std::size_t n) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n); (*this).private_detail_te_get_handle().sched(m, ins, n);
} }
void wait(module& p, instruction_ref ins, std::size_t wait_id) const void wait(module& m, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id); (*this).private_detail_te_get_handle().wait(m, ins, wait_id);
} }
void record(module& p, instruction_ref ins, std::size_t wait_id) const void record(module& m, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id); (*this).private_detail_te_get_handle().record(m, ins, wait_id);
} }
std::size_t weight(const operation& op) const std::size_t weight(const operation& op) const
...@@ -159,9 +164,9 @@ struct schedule_model ...@@ -159,9 +164,9 @@ struct schedule_model
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0; virtual std::size_t concurrency() const = 0;
virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0; virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0; virtual std::size_t weight(const operation& op) const = 0;
}; };
...@@ -195,22 +200,22 @@ struct schedule_model ...@@ -195,22 +200,22 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); } std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(module& p, instruction_ref ins, std::size_t n) const override void sched(module& m, instruction_ref ins, std::size_t n) const override
{ {
private_detail_te_value.sched(p, ins, n); private_detail_te_value.sched(m, ins, n);
} }
void wait(module& p, instruction_ref ins, std::size_t wait_id) const override void wait(module& m, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.wait(p, ins, wait_id); private_detail_te_value.wait(m, ins, wait_id);
} }
void record(module& p, instruction_ref ins, std::size_t wait_id) const override void record(module& m, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.record(p, ins, wait_id); private_detail_te_value.record(m, ins, wait_id);
} }
std::size_t weight(const operation& op) const override std::size_t weight(const operation& op) const override
...@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x) ...@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -35,7 +35,7 @@ struct shape ...@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) m(uint64_type, uint64_t)
// clang-format on // clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
...@@ -131,6 +131,8 @@ struct shape ...@@ -131,6 +131,8 @@ struct shape
shape with_lens(type_t t, const std::vector<std::size_t>& l) const; shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const; shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
...@@ -190,8 +192,7 @@ struct shape ...@@ -190,8 +192,7 @@ struct shape
{ {
switch(t) switch(t)
{ {
case tuple_type: case tuple_type: {
{
tv(); tv();
return; return;
} }
...@@ -228,10 +229,11 @@ struct shape ...@@ -228,10 +229,11 @@ struct shape
const std::vector<shape>& sub_shapes() const; const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private: private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
}; };
void migraphx_to_value(value& v, const shape& s); void migraphx_to_value(value& v, const shape& s);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment