Commit d71254c5 authored by Paul

Merge branch 'master' into mem-color-tests

parents 0cbb0368 7d972d2b
......@@ -14,5 +14,5 @@ gpu::target
cpu::target
-----------
.. doxygenstruct:: migraph::cpu::cpu_target
.. doxygenstruct:: migraph::cpu::target
......@@ -6,6 +6,7 @@ add_library(migraph
dead_code_elimination.cpp
eliminate_allocation.cpp
eliminate_contiguous.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp
env.cpp
generate.cpp
......
#include <algorithm>
#include <iterator>
#include <migraph/eliminate_concat.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {
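// A sketch of the rewrite this pass performs. Operator names come from the
// target's concat_optimization; "allocate" below is illustrative only:
//
//   before:                        after:
//     a1  = allocate                 super = allocate            // one buffer
//     x   = op1(..., a1)             l1 = load[offset=0](super)
//     a2  = allocate                 x  = op1(..., l1)
//     y   = op2(..., a2)             l2 = load[offset=x.bytes](super)
//     out = allocate                 y  = op2(..., l2)
//     c   = concat(x, y, out)        c  = identity(super, x, y)
//
// Each producer then writes directly into its slice of the concat output,
// so the concat copy is eliminated.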
void eliminate_concat::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// Look for the concat operator
if(ins->name() != concat_opt.name())
continue;
// If any inputs are literals then abort
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto arg) {
return arg->name() == "@literal";
}))
continue;
// We can only do this optimization when the concat axis is either the
// leftmost axis OR all the sizes to the left of that axis are equal to 1.
// Since we've already checked that the non-axis dimensions are identical,
// we only need to check the first input.
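// e.g. lens = {1, 1, 4, 5} with axis = 2 qualifies (all dims left of the
// axis are 1), while lens = {2, 3, 4} with axis = 1 does not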
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
if(concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
auto last = ins->inputs().back();
if(last->name() != concat_opt.allocate())
continue;
// Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations;
for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++)
{
auto last2 = (*ins2)->inputs().back();
if(last2->name() == concat_opt.allocate())
{
allocations.push_back(last2);
}
}
// Sort the allocations so that we know where to insert
// the "super" allocation
std::sort(
allocations.begin(), allocations.end(), [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
// Move "super" allocation to the front
auto first = allocations.front();
auto super = p.move_instruction(last, first);
std::size_t offset = 0;
for(auto x : allocations)
{
migraph::op::load op{x->get_shape(), offset};
p.replace_instruction(x, op, {super});
offset += x->get_shape().bytes();
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args);
}
}
}
} // namespace migraph
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <typeinfo>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
namespace migraph {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent optimization for the concat instruction
struct concat_optimization
{
/// The name of the target-dependent concat operator
std::string name() const;
/// The name of the target-dependent allocate operator
std::string allocate() const;
/// Return the target-independent concat operator
op::concat get_concat(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
struct concat_optimization
{
// Constructors
concat_optimization() = default;
template <typename PrivateDetailTypeErasedT>
concat_optimization(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
concat_optimization& operator=(PrivateDetailTypeErasedT value)
{
// Reuse the existing storage in place when we hold the only reference
// to a handle of the same type
if(private_detail_te_handle_mem_var != nullptr and
private_detail_te_handle_mem_var.unique())
{
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived != nullptr)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
return *this;
}
}
// Otherwise rebind to a freshly constructed handle
concat_optimization rhs(std::forward<PrivateDetailTypeErasedT>(value));
private_detail_te_handle_mem_var.swap(rhs.private_detail_te_handle_mem_var);
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
std::string allocate() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().allocate();
}
op::concat get_concat(const operation& op) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_concat(op);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual std::string allocate() const = 0;
virtual op::concat get_concat(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
std::string allocate() const override { return private_detail_te_value.allocate(); }
op::concat get_concat(const operation& op) const override
{
return private_detail_te_value.get_concat(op);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(concat_optimization& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const concat_optimization& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace migraph
#endif
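For reference, a minimal sketch of a target-side implementation of this interface; the operator names and the assumption that the target reuses `op::concat` directly are illustrative, not part of this patch:

struct example_concat_optimizer
{
    // assumed names for the target's concat and allocation operators
    std::string name() const { return "example::concat"; }
    std::string allocate() const { return "example::allocate"; }
    // recover the target-independent concat; here the target operator is
    // assumed to be op::concat itself
    op::concat get_concat(const operation& op) const
    {
        return any_cast<op::concat>(op);
    }
};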
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/concat_opt.hpp>
namespace migraph {
struct program;
struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
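A target would then schedule the pass with its own optimizer; a hypothetical pass list (names assumed, matching the sketch above):

std::vector<pass> get_passes(migraph::context&) const
{
    return {auto_contiguous{}, eliminate_concat{example_concat_optimizer{}}, lowering{}};
}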
......@@ -69,6 +69,8 @@ struct instruction
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static instruction_ref get_output_alias(instruction_ref ins);
private:
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
......
......@@ -43,6 +43,9 @@ struct operation
* the same as the `output` shape.
*/
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method that returns which input argument the output
/// aliases. Returns -1 when the output does not alias any input.
int output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
......@@ -108,12 +111,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
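// The three overloads above are the usual rank-dispatch detection idiom:
// the rank<1> overload is constrained by decltype(x.output_alias(shapes)),
// so it participates in overload resolution only when T provides
// output_alias; otherwise the call falls back to the rank<0> overload,
// which returns -1 (no aliasing). For example (types hypothetical):
//
//   struct aliasing_op { int output_alias(const std::vector<shape>&) const { return 0; } };
//   struct plain_op { };
//
//   output_alias_op(aliasing_op{}, shapes); // 0
//   output_alias_op(plain_op{}, shapes);    // -1 via the rank<0> fallback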
/*
* Type-erased interface for:
*
* struct operation
* {
* std::string name() const;
* int output_alias(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
* friend std::ostream& operator<<(std::ostream& os, const operation& op);
......@@ -185,6 +208,12 @@ struct operation
return (*this).private_detail_te_get_handle().name();
}
int output_alias(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().output_alias(input);
}
shape compute_shape(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -217,6 +246,7 @@ struct operation
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual int output_alias(const std::vector<shape>& input) const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
......@@ -254,8 +284,15 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); }
int output_alias(const std::vector<shape>& input) const override
{
return output_alias_op(private_detail_te_value, input);
}
shape compute_shape(const std::vector<shape>& input) const override
{
return private_detail_te_value.compute_shape(input);
}
......
......@@ -223,22 +223,6 @@ struct pooling
}
};
struct activation
{
std::string mode;
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
return os;
}
};
struct leaky_relu
{
std::string name() const { return "leaky_relu"; }
......@@ -296,6 +280,7 @@ struct transpose
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct contiguous
......@@ -359,6 +344,7 @@ struct concat
new_lens[axis] = new_dim_axis;
return {type, new_lens};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct slice
......@@ -440,6 +426,7 @@ struct slice
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct squeeze
......@@ -487,6 +474,7 @@ struct squeeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct unsqueeze
......@@ -525,6 +513,7 @@ struct unsqueeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct reshape
......@@ -576,6 +565,7 @@ struct reshape
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct dot
......@@ -613,9 +603,14 @@ struct unary
}
};
struct identity : unary
struct identity
{
std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
};
struct abs : unary
......@@ -673,6 +668,11 @@ struct neg : unary
std::string name() const { return "neg"; }
};
struct relu : unary
{
std::string name() const { return "relu"; }
};
struct softmax
{
std::string name() const { return "softmax"; }
......@@ -713,6 +713,7 @@ struct flatten
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct broadcast
{
......@@ -755,6 +756,7 @@ struct broadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct scalar
......@@ -776,6 +778,7 @@ struct scalar
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct binary
......@@ -828,6 +831,7 @@ struct load
{
return {s, args[0].data() + offset};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
struct outline
......
......@@ -161,12 +161,25 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this);
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return op.compute_shape(shapes);
return shapes;
}
instruction_ref instruction::get_output_alias(instruction_ref ins)
{
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
if(i < 0)
return ins;
return get_output_alias(ins->inputs().at(i));
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
return op.compute_shape(compute_shapes(args));
}
} // namespace migraph
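get_output_alias follows alias chains transitively until it reaches an instruction whose operator reports no alias. A sketch, assuming two ops that both alias argument 0:

// t = transpose(x);   // output_alias == 0
// r = reshape(t);     // output_alias == 0
// instruction::get_output_alias(r) walks r -> t -> x and returns x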
......@@ -6,7 +6,7 @@
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
......@@ -86,7 +86,7 @@ int main(int argc, char const* argv[])
else
{
// CPU target
prog.compile(migraph::cpu::cpu_target{});
prog.compile(migraph::cpu::target{});
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
auto labels = imageset.first;
auto input = imageset.second;
......
......@@ -52,7 +52,7 @@ struct onnx_parser
add_generic_op("Div", op::div{});
add_generic_op("MatMul", op::dot{});
add_generic_op("Mul", op::mul{});
add_generic_op("Relu", op::activation{"relu"});
add_generic_op("Relu", op::relu{});
add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{});
......@@ -62,6 +62,8 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
......@@ -148,7 +150,12 @@ struct onnx_parser
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::pooling op{name == "MaxPool" ? "max" : "average"};
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
if(starts_with(name, "Global"))
{
auto lens = args.front()->get_shape().lens();
op.lengths = {lens[2], lens[3]};
}
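// e.g. a {1, 3, 32, 32} NCHW input to GlobalAveragePool yields
// op.lengths = {32, 32}, so the pooling window spans the whole spatial plane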
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
......@@ -584,10 +591,15 @@ struct onnx_parser
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) { return d.dim_value(); });
std::transform(
tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) {
if(not d.has_dim_value())
{
long default_batch_size = 1; // FIXME
return default_batch_size;
}
return d.dim_value();
});
return {shape_type, dims};
}
};
......
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
......@@ -18,7 +18,7 @@ template <class F>
migraph::argument run_cpu(F f)
{
auto p = f();
p.compile(migraph::cpu::cpu_target{});
p.compile(migraph::cpu::target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
......
......@@ -195,6 +195,7 @@ void memory_coloring_impl::register_operand_alias()
operand_alias["transpose"] = 0;
operand_alias["flatten"] = 0;
operand_alias["broadcast"] = 0;
operand_alias["identity"] = 0;
operand_alias["reshape"] = 0;
operand_alias["pass"] = 0;
operand_alias["scalar"] = 0;
......
......@@ -2,6 +2,7 @@
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
#include <migraph/env.hpp>
#include <migraph/ranges.hpp>
#include <migraph/time.hpp>
#include <migraph/iterator_for.hpp>
#include <iostream>
......@@ -280,7 +281,7 @@ void program::compile(const target& t, tracer trace)
{
assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context();
if(not trace.enabled() and enabled(MIGRAPH_TRACE_COMPILE{}))
if(not trace.enabled() or enabled(MIGRAPH_TRACE_COMPILE{}))
trace = tracer{std::cout};
trace(*this);
trace();
......@@ -329,8 +330,11 @@ argument generic_eval(const program& p,
else if(ins->name() == "@param")
{
results.emplace(ins, trace(ins, [&] {
return params.at(
any_cast<builtin::param>(ins->get_operator()).parameter);
auto param_name =
any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPH_THROW("Parameter not found: " + param_name);
return params.at(param_name);
}));
}
else if(ins->name() == "@outline")
......
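With the added check, evaluating a program without binding a declared parameter now fails with a descriptive message; a sketch (parameter name hypothetical):

migraph::program::parameter_map m; // parameter "x" deliberately left unbound
p.eval(m); // throws "Parameter not found: x" instead of std::out_of_range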
add_library(migraph_cpu
cpu_target.cpp
cpu_lowering.cpp
target.cpp
lowering.cpp
gemm.cpp
)
......
......@@ -6,7 +6,7 @@
namespace migraph {
namespace cpu {
struct cpu_lowering
struct lowering
{
std::string name() const { return "cpu::lowering"; }
void apply(program& p) const;
......
......@@ -7,7 +7,7 @@
namespace migraph {
namespace cpu {
struct cpu_target
struct target
{
std::string name() const;
std::vector<pass> get_passes(migraph::context& ctx) const;
......
#include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/cpu/lowering.hpp>
#include <migraph/instruction.hpp>
#include <migraph/dfor.hpp>
#include <migraph/operators.hpp>
......@@ -606,6 +606,7 @@ struct cpu_apply
apply_map["sin"] = simple_op<cpu_unary<sin_op>>();
apply_map["cos"] = simple_op<cpu_unary<cos_op>>();
apply_map["tan"] = simple_op<cpu_unary<tan_op>>();
apply_map["relu"] = simple_op<cpu_unary<relu_op>>();
apply_map["add"] = simple_op<cpu_binary<add_op>>();
apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
......@@ -619,11 +620,7 @@ struct cpu_apply
init();
for(auto it : iterator_for(*prog))
{
if(it->name() == "activation")
{
apply_activation(it);
}
else if(it->name() == "pooling")
if(it->name() == "pooling")
{
apply_pooling(it);
}
......@@ -647,13 +644,6 @@ struct cpu_apply
prog->replace_instruction(ins, T{op}, ins->inputs());
}
void apply_activation(instruction_ref ins)
{
auto&& op = any_cast<op::activation>(ins->get_operator());
if(op.mode == "relu")
prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->inputs());
}
void apply_pooling(instruction_ref ins)
{
auto&& op = any_cast<op::pooling>(ins->get_operator());
......@@ -664,7 +654,7 @@ struct cpu_apply
}
};
void cpu_lowering::apply(program& p) const { cpu_apply{&p}.apply(); }
void lowering::apply(program& p) const { cpu_apply{&p}.apply(); }
} // namespace cpu
......
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/cpu/target.hpp>
#include <migraph/cpu/lowering.hpp>
#include <migraph/auto_contiguous.hpp>
namespace migraph {
namespace cpu {
std::string cpu_target::name() const { return "cpu"; }
std::string target::name() const { return "cpu"; }
std::vector<pass> cpu_target::get_passes(migraph::context&) const
std::vector<pass> target::get_passes(migraph::context&) const
{
return {auto_contiguous{}, cpu_lowering{}};
return {auto_contiguous{}, lowering{}};
}
} // namespace cpu
......
......@@ -25,6 +25,7 @@ struct hip_add
std::string name() const { return "gpu::add"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context&, const shape&, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct miopen_add
......@@ -33,6 +34,7 @@ struct miopen_add
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
......
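Since the GPU target passes the preallocated output buffer as the trailing argument, both ops report the last input as the alias; sketch:

// args = {x, y, out}: output_alias returns shapes.size() - 1,
// i.e. the index of the output buffer `out`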