Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
...@@ -37,48 +37,53 @@ struct squeeze ...@@ -37,48 +37,53 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty()) if(axes.empty())
{ {
std::copy_if(old_lens.begin(), for(auto i : range(old_lens.size()))
old_lens.end(), {
std::back_inserter(new_lens), if(old_lens[i] != 1)
[](auto len) { return len != 1; }); {
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
} }
else else
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(auto i : range(old_lens.size()))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
} }
} }
} }
if(new_lens.empty()) if(new_lens.empty())
{ {
return shape{type}; return shape{type};
} }
else else
{ {
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -72,8 +72,6 @@ struct step ...@@ -72,8 +72,6 @@ struct step
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -56,7 +56,6 @@ struct transpose ...@@ -56,7 +56,6 @@ struct transpose
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -64,7 +64,6 @@ struct unary : op_name<Derived> ...@@ -64,7 +64,6 @@ struct unary : op_name<Derived>
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
}); });
}); });
return result; return result;
......
...@@ -37,11 +37,11 @@ struct unsqueeze ...@@ -37,11 +37,11 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar()) if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(old_lens.size() == 1 and old_lens.front() == 1)
...@@ -53,25 +53,34 @@ struct unsqueeze ...@@ -53,25 +53,34 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
{
new_strides[i] = old_lens[0] * old_strides[0];
}
else // unsqueeze on middle or last axes
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
}
} }
else else
{ {
new_lens[i] = old_lens[p++]; new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
} }
} }
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&) ...@@ -445,35 +445,62 @@ lifetime get_lifetime_op(const T&)
} // namespace detail } // namespace detail
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct operation struct operation
* { {
* std::string name() const; //
* bool is_context_free() const; std::string name() const;
* bool need_normalization() const; // (optional)
* bool has_finalize() const; bool is_context_free() const;
* lifetime get_lifetime() const; // (optional)
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; bool need_normalization() const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ; // (optional)
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; bool has_finalize() const;
* shape compute_shape(const std::vector<shape>& input) const; // (optional)
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& lifetime get_lifetime() const;
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>& // (optional)
* input) const; argument compute(const shape& output,const std::vector<argument>& input) std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* const; argument compute(const shape& output,const std::vector<argument>& input,const // (optional)
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const value compile(context& ctx, const shape& output, const std::vector<shape>& input);
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const // (optional)
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>& void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
* module_args,std::function<std::vector<argument>(module_ref&, const // (optional)
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void shape compute_shape(const std::vector<shape>& input) const;
* from_value(const value& v) ; value attributes() const; friend std::ostream & // (optional)
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation & shape compute_shape(const std::vector<shape>& inputs,
* x,const operation & y) ; const std::vector<module_ref>& mod_args) const;
* }; // (optional)
* argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
*/ // (optional)
argument compute(const shape& output, const std::vector<argument>& input) const;
// (optional)
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
// (optional)
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
value attributes() const;
//
friend std::ostream& operator<<(std::ostream& os, const operation& op);
//
friend bool operator==(const operation& x, const operation& y);
};
#else
struct operation struct operation
{ {
...@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x) ...@@ -1222,6 +1249,7 @@ inline const ValueType& any_cast(const operation& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
......
...@@ -35,12 +35,14 @@ ...@@ -35,12 +35,14 @@
#include <migraphx/op/flatten.hpp> #include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp> #include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp> #include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp> #include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp> #include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp> #include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp> #include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/leaky_relu.hpp> #include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp> #include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp> #include <migraphx/op/load.hpp>
...@@ -85,7 +87,12 @@ ...@@ -85,7 +87,12 @@
#include <migraphx/op/round.hpp> #include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK) #if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L #if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1 #define MIGRAPHX_HAS_OPTIONAL 1
#else #else
......
...@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs) ...@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
{ {
dfor(xs...)(f); dfor(xs...)(f);
} }
}; };
} }
......
...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, std::size_t min_grain, F f)
{ {
const auto threadsize = const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain); n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f); par_for_impl(n, threadsize, f);
} }
......
...@@ -58,17 +58,20 @@ void module_pass_manager_apply(const T& x, module_pass_manager& mpm) ...@@ -58,17 +58,20 @@ void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
} // namespace detail } // namespace detail
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct pass struct pass
* { {
* std::string name() const; //
* void apply(module_pass_manager & mpm) const; std::string name() const;
* void apply(program & p) const; // (optional)
* }; void apply(module_pass_manager& mpm) const;
* // (optional)
*/ void apply(program& p) const;
};
#else
struct pass struct pass
{ {
...@@ -303,6 +306,7 @@ inline const ValueType& any_cast(const pass& x) ...@@ -303,6 +306,7 @@ inline const ValueType& any_cast(const pass& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -81,6 +81,9 @@ struct program ...@@ -81,6 +81,9 @@ struct program
const std::function<void(instruction_ref, const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>& std::unordered_map<instruction_ref, std::string>)>&
print_func) const; print_func) const;
void print(const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct propagate_constant struct propagate_constant
{ {
std::string name() const { return "propagate_constant"; } std::string name() const { return "propagate_constant"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1) ...@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1)
template <class T, class... Ts> template <class T, class... Ts>
auto visit_all(T&& x, Ts&&... xs) auto visit_all(T&& x, Ts&&... xs)
{ {
auto&& s = x.get_shape(); auto&& s = x.get_shape();
// cppcheck-suppress redundantInitialization
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...}; std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm struct rewrite_batchnorm
{ {
std::string name() const { return "rewrite_batchnorm"; } std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling struct rewrite_pooling
{ {
std::string name() const { return "rewrite_pooling"; } std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -19,22 +19,22 @@ struct module; ...@@ -19,22 +19,22 @@ struct module;
struct rewrite_rnn struct rewrite_rnn
{ {
std::string name() const { return "rewrite_rnn"; } std::string name() const { return "rewrite_rnn"; }
void apply(module& prog) const; void apply(module& m) const;
private: private:
// for vanilla rnn operators // for vanilla rnn operators
void apply_vanilla_rnn(module& prog, instruction_ref ins) const; void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
module& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const; std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators // for gru operators
void apply_gru(module& prog, instruction_ref ins) const; void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward, std::vector<instruction_ref> gru_cell(bool is_forward,
module& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
...@@ -44,9 +44,9 @@ struct rewrite_rnn ...@@ -44,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const; std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators // for lstm operators
void apply_lstm(module& prog, instruction_ref ins) const; void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward, std::vector<instruction_ref> lstm_cell(bool is_forward,
module& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
const operation& actv_func1, const operation& actv_func1,
...@@ -55,24 +55,23 @@ struct rewrite_rnn ...@@ -55,24 +55,23 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const; std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const; bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& prog, instruction_ref replace_last_hs_output(module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref last_hs_output, instruction_ref last_hs_output,
op::rnn_direction dirct) const; op::rnn_direction dirct) const;
void replace_last_cell_output(module& prog, void replace_last_cell_output(module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref cell_outputs, instruction_ref cell_outputs,
instruction_ref last_cell_output, instruction_ref last_cell_output,
op::rnn_direction dirct) const; op::rnn_direction dirct) const;
std::size_t std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& prog, instruction_ref pad_hidden_states(module& m,
instruction_ref seq, instruction_ref seq,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref hs) const; instruction_ref hs) const;
......
...@@ -19,7 +19,7 @@ struct schedule ...@@ -19,7 +19,7 @@ struct schedule
schedule_model model{}; schedule_model model{};
bool enable = true; bool enable = true;
std::string name() const { return "schedule"; } std::string name() const { return "schedule"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -26,30 +26,35 @@ struct schedule_model ...@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed /// Get the number of concurrent instruction allowed
std::size_t concurrency() const; std::size_t concurrency() const;
/// Schedule a concurrent instruction /// Schedule a concurrent instruction
void sched(module& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction // Insert necessary waits before an instruction
void wait(module& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction // Insert necessary records after an instruction
void record(module& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation /// Compute weights for an operation
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
#else #else
/* #ifdef TYPE_ERASED_DECLARATION
* Type-erased interface for:
* // Type-erased interface for:
* struct schedule_model struct schedule_model
* { {
* std::size_t concurrency() const; //
* void sched(module& p,instruction_ref ins,std::size_t n) const; std::size_t concurrency() const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const; //
* void record(module& p,instruction_ref ins,std::size_t wait_id) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
* std::size_t weight(const operation& op) const; //
* }; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
* //
*/ void record(module& m, instruction_ref ins, std::size_t wait_id) const;
//
std::size_t weight(const operation& op) const;
};
#else
struct schedule_model struct schedule_model
{ {
...@@ -120,22 +125,22 @@ struct schedule_model ...@@ -120,22 +125,22 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency(); return (*this).private_detail_te_get_handle().concurrency();
} }
void sched(module& p, instruction_ref ins, std::size_t n) const void sched(module& m, instruction_ref ins, std::size_t n) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n); (*this).private_detail_te_get_handle().sched(m, ins, n);
} }
void wait(module& p, instruction_ref ins, std::size_t wait_id) const void wait(module& m, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id); (*this).private_detail_te_get_handle().wait(m, ins, wait_id);
} }
void record(module& p, instruction_ref ins, std::size_t wait_id) const void record(module& m, instruction_ref ins, std::size_t wait_id) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id); (*this).private_detail_te_get_handle().record(m, ins, wait_id);
} }
std::size_t weight(const operation& op) const std::size_t weight(const operation& op) const
...@@ -159,9 +164,9 @@ struct schedule_model ...@@ -159,9 +164,9 @@ struct schedule_model
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0; virtual std::size_t concurrency() const = 0;
virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0; virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0; virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0; virtual std::size_t weight(const operation& op) const = 0;
}; };
...@@ -195,22 +200,22 @@ struct schedule_model ...@@ -195,22 +200,22 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); } std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(module& p, instruction_ref ins, std::size_t n) const override void sched(module& m, instruction_ref ins, std::size_t n) const override
{ {
private_detail_te_value.sched(p, ins, n); private_detail_te_value.sched(m, ins, n);
} }
void wait(module& p, instruction_ref ins, std::size_t wait_id) const override void wait(module& m, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.wait(p, ins, wait_id); private_detail_te_value.wait(m, ins, wait_id);
} }
void record(module& p, instruction_ref ins, std::size_t wait_id) const override void record(module& m, instruction_ref ins, std::size_t wait_id) const override
{ {
private_detail_te_value.record(p, ins, wait_id); private_detail_te_value.record(m, ins, wait_id);
} }
std::size_t weight(const operation& op) const override std::size_t weight(const operation& op) const override
...@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x) ...@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw std::bad_cast(); throw std::bad_cast();
return *y; return *y;
} }
#endif
#endif #endif
......
...@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{}) ...@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
value result = value::array{}; value result = value::array{};
for(auto&& y : x) for(auto&& y : x)
{ {
auto e = to_value(y);
result.insert(to_value(y)); result.insert(to_value(y));
} }
return result; return result;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment