Unverified Commit 351007d4 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fusion of pointwise operators (#969)

Adds a pass to fuse pointwise operators into one "pointwsie" op that has a submodule which does the calculation.
parent 77164f3c
......@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp
env.cpp
file_buffer.cpp
fuse_pointwise.cpp
generate.cpp
inline_module.cpp
insert_pad.cpp
......@@ -133,6 +134,7 @@ register_migraphx_ops(
nonzero
outline
pad
pointwise
pooling
pow
prefix_scan_sum
......
......@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const
return result;
}
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static literal get_scalar(instruction_ref ins)
{
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
return r;
}
static void create_pointwise_modules(module_pass_manager& mpm)
{
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
auto* pm = mpm.create_module("pointwise" + std::to_string(n++));
pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
pointwise_inputs.push_back(input);
param_map[input] = pm->add_parameter("x" + std::to_string(param_map.size()),
shape{input->get_shape().type()});
}
else
{
param_map[input] = pm->add_literal(scalar);
}
}
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return param_map[input]; });
auto r = pm->add_instruction(ins->get_operator(), inputs);
pm->add_return({r});
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
}
}
static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output)
{
module_ref pm = ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);
auto last = std::prev(pm->end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);
std::vector<instruction_ref> inputs = ins->inputs();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i));
input_map[input] = param;
}
// Add the new parameter and additional inputs
for(auto i : range(output->inputs().size()))
{
auto input = output->inputs()[i];
auto param = xm->get_parameter("x" + std::to_string(i));
if(input == ins)
{
map_ins[param] = last->inputs().front();
input_map[input] = map_ins[param];
}
// Avoid duplicate paramter inputs
else if(contains(input_map, input))
{
map_ins[param] = input_map[input];
}
else
{
map_ins[param] =
pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
return inputs;
}
static bool find_pointwise_modules(module& m)
{
bool changed = false;
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty())
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto new_inputs = append_pointwise_module(*it, ins);
m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs());
m.replace_instruction(ins, *it);
m.move_instruction(*it, ins);
changed = true;
}
return changed;
}
void fuse_pointwise::apply(module_pass_manager& mpm) const
{
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 8; i++)
{
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -67,6 +67,9 @@ struct argument : raw_data<argument>
std::vector<argument> get_sub_objects() const;
/// Return the ith element
argument element(std::size_t i) const;
private:
void assign_buffer(std::function<char*()> d);
struct data_t
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
struct fuse_pointwise
{
std::string name() const { return "fuse_pointwise"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
......@@ -152,30 +152,4 @@ struct instruction
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(&*x);
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return &*x == &*y;
}
};
} // namespace std
#endif
......@@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
template <>
struct hash<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
result_type operator()(const migraphx::instruction_ref& x) const noexcept
{
return std::hash<migraphx::instruction*>{}(migraphx::as_address(x));
}
};
template <>
struct equal_to<migraphx::instruction_ref>
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
result_type operator()(const migraphx::instruction_ref& x,
const migraphx::instruction_ref& y) const noexcept
{
return migraphx::as_address(x) == migraphx::as_address(y);
}
};
} // namespace std
#endif
......@@ -96,6 +96,11 @@ struct module
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
......@@ -110,6 +115,8 @@ struct module
instruction_ref add_return(std::vector<instruction_ref> args);
instruction_ref replace_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
......
#ifndef MIGRAPHX_GUARD_OP_POINTWISE_HPP
#define MIGRAPHX_GUARD_OP_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pointwise
{
std::string name() const { return "pointwise"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
if(mods.size() != 1)
{
MIGRAPHX_THROW("should have one submodule.");
}
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
check_shapes{inputs, *this}.has(pnames.size()).same_dims();
for(auto i : range(pnames.size()))
{
auto s1 = pm->get_parameter(pnames[i])->get_shape();
auto s2 = inputs[i];
if(s1.type() != s2.type())
MIGRAPHX_THROW("Mismatch type");
}
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("submodule should have only one output.");
auto type = pm->get_output_shapes().front().type();
return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs));
}
argument compute(const shape& output_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
argument output{output_shape};
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
par_for(output_shape.elements(), [&](auto i) {
std::unordered_map<std::string, argument> params;
std::transform(
pnames.begin(),
pnames.end(),
args.begin(),
std::inserter(params, params.end()),
[&](auto&& name, auto&& arg) { return std::make_pair(name, arg.element(i)); });
auto results = run(pm, params);
assert(results.size() == 1);
visit_all(output, results.front())([&](auto out, auto x) { out[i] = x.front(); });
});
return output;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_OP_POINTWISE_HPP
......@@ -11,49 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
const auto* smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*smod))
{
instruction_ref copy_ins{};
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
mod_outputs = {copy_ins};
}
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
......
......@@ -468,5 +468,11 @@ std::vector<shape> try_compute_shape(const operation& op, const std::vector<shap
}
return {new_shape};
}
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
return std::addressof(*ins);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <iterator>
#include <migraphx/module.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
......@@ -302,6 +303,55 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src;
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
}
instruction_ref module::add_literal(literal l)
{
impl->emplace_front(std::move(l));
......@@ -332,6 +382,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result;
}
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
auto last = std::prev(this->end());
// If there is no return then add a return
if(last->name() != "@return")
return this->add_return(args);
shape r = compute_shape(last->get_operator(), args);
instruction::replace(last, last->get_operator(), r, std::move(args));
assert(last->valid(begin()));
return last;
}
shape module::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
......
File mode changed from 100644 to 100755
#include "migraphx/dead_code_elimination.hpp"
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
auto r = f(pm, params);
pm->add_return({r});
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op(name), inputs);
};
}
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1 == p2);
}
TEST_CASE(double_add)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(used_twice_not_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, y);
auto add3 = mm->add_instruction(migraphx::make_op("add"), pass, add2);
mm->add_return({add3});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto fadd =
add_pointwise(p2, "pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2);
});
mm->add_return({fadd});
}
EXPECT(p1 == p2);
}
TEST_CASE(used_twice_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x);
auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y);
auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3);
mm->add_return({add4});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]);
auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add2, add3);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(duplicate_inputs)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, x);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, y);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]);
});
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, y}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(scalar_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto y =
mm->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), one);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -2865,6 +2865,26 @@ TEST_CASE(pad_test_lowest_half)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pointwise_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
auto* pm = p.create_module("pointwise");
auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type});
auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type});
pm->add_instruction(migraphx::make_op("add"), x1, x2);
mm->add_instruction(migraphx::make_op("pointwise"), {l1, l2}, {pm});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment