Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
......@@ -35,12 +35,14 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp>
......@@ -49,6 +51,7 @@
#include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
......@@ -56,6 +59,8 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
......@@ -78,10 +83,16 @@
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......@@ -95,11 +106,13 @@
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif
......@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_OPTIONAL 1
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#elif defined(__has_include)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
......
......@@ -41,7 +41,6 @@ auto par_dfor(Ts... xs)
{
dfor(xs...)(f);
}
};
}
......
......@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize =
std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
}
......
......@@ -8,12 +8,14 @@
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct module_pass_manager;
#ifdef DOXYGEN
......@@ -24,6 +26,7 @@ struct pass
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the module
void apply(module_pass_manager& mpm) const;
void apply(module& m) const;
/// Run the pass on the program
void apply(program& p) const;
......@@ -31,17 +34,44 @@ struct pass
#else
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(module & m) const;
* void apply(program & p) const;
* };
*
*/
module& get_module(module_pass_manager& mpm);
namespace detail {
template <class T>
auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm)
-> decltype(x.apply(get_module(mpm)))
{
return x.apply(get_module(mpm));
}
template <class T>
void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&)
{
}
template <class T>
void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
{
module_pass_manager_apply(rank<1>{}, x, mpm);
}
} // namespace detail
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct pass
{
//
std::string name() const;
// (optional)
void apply(module_pass_manager& mpm) const;
// (optional)
void apply(program& p) const;
};
#else
struct pass
{
......@@ -112,10 +142,10 @@ struct pass
return (*this).private_detail_te_get_handle().name();
}
void apply(module& m) const
void apply(module_pass_manager& mpm) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(m);
(*this).private_detail_te_get_handle().apply(mpm);
}
void apply(program& p) const
......@@ -137,22 +167,24 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(module& m) const = 0;
virtual void apply(program& p) const = 0;
virtual std::string name() const = 0;
virtual void apply(module_pass_manager& mpm) const = 0;
virtual void apply(program& p) const = 0;
};
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, module& m)
-> decltype(private_detail_te_self.apply(m))
static auto
private_detail_te_default_apply(char, T&& private_detail_te_self, module_pass_manager& mpm)
-> decltype(private_detail_te_self.apply(mpm))
{
private_detail_te_self.apply(m);
private_detail_te_self.apply(mpm);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, module& m)
static void
private_detail_te_default_apply(float, T&& private_detail_te_self, module_pass_manager& mpm)
{
migraphx::nop(private_detail_te_self, m);
migraphx::detail::module_pass_manager_apply(private_detail_te_self, mpm);
}
template <class T>
......@@ -198,10 +230,10 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); }
void apply(module& m) const override
void apply(module_pass_manager& mpm) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, m);
private_detail_te_default_apply(char(0), private_detail_te_value, mpm);
}
void apply(program& p) const override
......@@ -274,6 +306,7 @@ inline const ValueType& any_cast(const pass& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
#include <migraphx/pass.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager
{
module_pass_manager() = default;
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual void run_pass(const pass& p) = 0;
protected:
virtual ~module_pass_manager() {}
};
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
......
......@@ -23,6 +23,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl;
struct marker;
/**
* @brief Stores the instruction stream
*/
......@@ -65,7 +67,10 @@ struct program
void finalize();
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void
perf_report(std::ostream& os, std::size_t n, parameter_map params, std::size_t batch = 1) const;
void mark(const parameter_map& params, marker&& m);
value to_value() const;
void from_value(const value& v);
......@@ -76,6 +81,9 @@ struct program
const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print(const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const;
......
......@@ -15,7 +15,7 @@ struct module;
struct propagate_constant
{
std::string name() const { return "propagate_constant"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -17,32 +17,10 @@ struct program;
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
{
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
std::is_lvalue_reference<T>{},
"Dangling reference to target!");
return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <vector>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Decompose operators.
* quantize a program to fp16
*/
struct remap
struct quantize_fp16_pass
{
std::string name() const { return "remap"; }
void apply(module& p) const;
std::vector<std::string> ins_names = {"all"};
std::string name() const { return "quantize_fp16"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* capture inputs of operators to be quantized to int8
*/
struct capture_arguments_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::function<void(std::size_t, std::vector<argument>)> f{};
std::size_t* param_index = nullptr;
std::string name() const { return "capture_arguments"; }
void apply(module& m) const;
};
/**
* quantize a program to int8
*/
struct quantize_int8_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_int8"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1)
template <class T, class... Ts>
auto visit_all(T&& x, Ts&&... xs)
{
auto&& s = x.get_shape();
// cppcheck-suppress redundantInitialization
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
......
......@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -19,22 +19,22 @@ struct module;
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(module& prog) const;
void apply(module& m) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(module& prog, instruction_ref ins) const;
void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(module& prog, instruction_ref ins) const;
void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -44,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(module& prog, instruction_ref ins) const;
void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -55,24 +55,23 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& prog,
bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(module& prog,
void replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t
get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& prog,
instruction_ref pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class LoopModel, class T>
argument run_loop(const LoopModel& model,
T& ctx,
std::vector<argument> args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run)
{
std::vector<std::vector<argument>> results;
// process argu lists
auto iter_num = args.at(0).at<int64_t>();
auto cond = args.at(1).at<bool>();
auto input_num = (args.size() - 2) / 2;
auto dep_num = input_num - 2;
module_ref mod = mods.at(0);
auto param_name_shapes = mod->get_parameter_shapes();
auto param_names = mod->get_parameter_names();
std::vector<argument> dep0(args.begin() + input_num + 1, args.begin() + 2 * input_num);
std::vector<argument> dep1(args.begin() + 2 * input_num, args.begin() + 2 * input_num + 1);
auto ins_outputs = args.back().get_sub_objects();
dep1.insert(dep1.end(), ins_outputs.begin(), ins_outputs.begin() + dep_num);
std::array<std::vector<argument>, 2> loop_carry_deps = {dep0, dep1};
// loop iter argument
std::vector<argument> in_args = {args.at(input_num), dep1.at(0)};
in_args.insert(in_args.end(), args.begin() + 2, args.begin() + input_num);
std::vector<argument> out_args = dep0;
out_args.insert(out_args.end(), ins_outputs.begin() + dep_num, ins_outputs.end());
std::vector<argument> scan_outputs(ins_outputs.begin() + dep_num, ins_outputs.end());
auto out_param_indices = model.get_output_params(*mod);
int64_t iter = 0;
for(iter = 0; iter < iter_num and cond; ++iter)
{
// copy iter num and cond to device memory
model.copy(ctx, iter, in_args.at(0));
model.copy(ctx, cond, in_args.at(1));
// wrap up the inputs and outputs
std::unordered_map<std::string, argument> params;
int input_index = 0;
for(const auto& name : param_names)
{
auto ps = mod->get_parameter_shape(name);
if(ps == shape{})
{
continue;
}
// it is an input parameter
if(not contains(out_param_indices, name))
{
params[name] = in_args.at(input_index++);
}
else
{
auto output_index = out_param_indices[name];
if(output_index > dep_num)
{
const auto& arg = out_args.at(output_index);
assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes());
params[name] = argument(ps, arg.data() + iter * ps.bytes());
}
else
{
params[name] = out_args.at(output_index);
}
}
}
auto mod_args = run(mod, params);
// copy back cond to be used next iteration
model.copy(ctx, mod_args.at(0), cond);
// mod outputs are used as next loop input
std::copy(mod_args.begin(), mod_args.begin() + dep_num + 1, in_args.begin() + 1);
const auto& dep_out = loop_carry_deps[(iter + 1) % 2];
std::copy(dep_out.begin(), dep_out.end(), out_args.begin());
std::vector<argument> mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end());
model.append(mod_scan_outs, scan_outputs, iter);
}
out_args.erase(out_args.begin());
std::copy(in_args.begin() + 2, in_args.end(), out_args.begin());
model.set_zero(ctx, scan_outputs, iter);
return {out_args};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -19,7 +19,7 @@ struct schedule
schedule_model model{};
bool enable = true;
std::string name() const { return "schedule"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -26,30 +26,35 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
std::size_t concurrency() const;
/// Schedule a concurrent instruction
void sched(module& p, instruction_ref ins, std::size_t n) const;
void sched(module& m, instruction_ref ins, std::size_t n) const;
// Insert necessary waits before an instruction
void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
// Insert necessary records after an instruction
void record(module& p, instruction_ref ins, std::size_t wait_id) const;
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
/// Compute weights for an operation
std::size_t weight(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct schedule_model
* {
* std::size_t concurrency() const;
* void sched(module& p,instruction_ref ins,std::size_t n) const;
* void wait(module& p,instruction_ref ins,std::size_t wait_id) const;
* void record(module& p,instruction_ref ins,std::size_t wait_id) const;
* std::size_t weight(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct schedule_model
{
//
std::size_t concurrency() const;
//
void sched(module& m, instruction_ref ins, std::size_t n) const;
//
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
//
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
//
std::size_t weight(const operation& op) const;
};
#else
struct schedule_model
{
......@@ -120,22 +125,22 @@ struct schedule_model
return (*this).private_detail_te_get_handle().concurrency();
}
void sched(module& p, instruction_ref ins, std::size_t n) const
void sched(module& m, instruction_ref ins, std::size_t n) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().sched(p, ins, n);
(*this).private_detail_te_get_handle().sched(m, ins, n);
}
void wait(module& p, instruction_ref ins, std::size_t wait_id) const
void wait(module& m, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait(p, ins, wait_id);
(*this).private_detail_te_get_handle().wait(m, ins, wait_id);
}
void record(module& p, instruction_ref ins, std::size_t wait_id) const
void record(module& m, instruction_ref ins, std::size_t wait_id) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().record(p, ins, wait_id);
(*this).private_detail_te_get_handle().record(m, ins, wait_id);
}
std::size_t weight(const operation& op) const
......@@ -159,9 +164,9 @@ struct schedule_model
virtual const std::type_info& type() const = 0;
virtual std::size_t concurrency() const = 0;
virtual void sched(module& p, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& p, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0;
virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0;
virtual std::size_t weight(const operation& op) const = 0;
};
......@@ -195,22 +200,22 @@ struct schedule_model
std::size_t concurrency() const override { return private_detail_te_value.concurrency(); }
void sched(module& p, instruction_ref ins, std::size_t n) const override
void sched(module& m, instruction_ref ins, std::size_t n) const override
{
private_detail_te_value.sched(p, ins, n);
private_detail_te_value.sched(m, ins, n);
}
void wait(module& p, instruction_ref ins, std::size_t wait_id) const override
void wait(module& m, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.wait(p, ins, wait_id);
private_detail_te_value.wait(m, ins, wait_id);
}
void record(module& p, instruction_ref ins, std::size_t wait_id) const override
void record(module& m, instruction_ref ins, std::size_t wait_id) const override
{
private_detail_te_value.record(p, ins, wait_id);
private_detail_te_value.record(m, ins, wait_id);
}
std::size_t weight(const operation& op) const override
......@@ -283,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x)
throw std::bad_cast();
return *y;
}
#endif
#endif
......
......@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
value result = value::array{};
for(auto&& y : x)
{
auto e = to_value(y);
result.insert(to_value(y));
}
return result;
......
......@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
......@@ -131,6 +131,8 @@ struct shape
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
shape with_type(type_t t) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......@@ -186,8 +188,7 @@ struct shape
{
switch(t)
{
case tuple_type:
{
case tuple_type: {
tv();
return;
}
......@@ -224,10 +225,11 @@ struct shape
const std::vector<shape>& sub_shapes() const;
std::size_t element_space() const;
private:
shape(std::shared_ptr<shape_impl> pimpl);
std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
};
void migraphx_to_value(value& v, const shape& s);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment