Unverified Commit 157935ff authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Conditionally enable pointwise fusion (#992)

This enables the pointwise fusions using the MIGRAPHX_ENABLE_POINTWISE_FUSION env variable. Its disabled by default since MIOpen fusions need to be refactored.

This also adds a compile_ops pass to compile the pointwise modules. All tests except test_gpu_fast_math passes with MIGRAPHX_ENABLE_POINTWISE_FUSION=1 set.
parent 38287064
......@@ -60,9 +60,9 @@ TEST_CASE(single)
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, z}, single_pointwise("add"));
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1 == p2);
......@@ -84,14 +84,15 @@ TEST_CASE(double_add)
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_return({fadd});
}
EXPECT(p1.sort() == p2.sort());
......@@ -117,10 +118,10 @@ TEST_CASE(used_twice_not_fused)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x, y}, single_pointwise("add"));
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto fadd =
add_pointwise(p2, "pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto fadd = add_pointwise(
p2, "main:pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) {
auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2);
});
......@@ -149,7 +150,7 @@ TEST_CASE(used_twice_fused)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fadd = add_pointwise(p2, "pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]);
auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]);
......@@ -179,11 +180,11 @@ TEST_CASE(duplicate_inputs)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]);
});
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "pointwise1", {pass, y}, single_pointwise("add"));
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, y}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1.sort() == p2.sort());
......@@ -207,7 +208,35 @@ TEST_CASE(scalar_input)
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
mm->add_return({add1});
}
EXPECT(p1 == p2);
}
TEST_CASE(contiguous_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto one = mm->add_literal(1.0f);
auto yb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), one);
auto y = mm->add_instruction(migraphx::make_op("contiguous"), yb);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
auto y = pm->add_literal(1.0f);
return pm->add_instruction(migraphx::make_op("add"), inputs[0], y);
});
......@@ -216,4 +245,32 @@ TEST_CASE(scalar_input)
EXPECT(p1 == p2);
}
TEST_CASE(all_scalar_input)
{
migraphx::shape s{migraphx::shape::float_type};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({add1});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) {
return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
});
mm->add_return({add1});
}
EXPECT(p1.get_output_shapes().size() == 1);
EXPECT(p1.get_output_shapes().front().scalar());
EXPECT(p1.get_output_shapes() == p2.get_output_shapes());
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/stringutils.hpp>
#include <test.hpp>
TEST_CASE(interpolate_string_simple1)
{
std::string input = "Hello ${w}!";
auto s = migraphx::interpolate_string(input, {{"w", "world"}});
EXPECT(s == "Hello world!");
}
TEST_CASE(interpolate_string_simple2)
{
std::string input = "${hello}";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "bye");
}
TEST_CASE(interpolate_string_unbalanced)
{
std::string input = "${hello";
EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"hello", "bye"}}); }));
}
TEST_CASE(interpolate_string_extra_space)
{
std::string input = "${ hello }";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "bye");
}
TEST_CASE(interpolate_string_multiple)
{
std::string input = "${h} ${w}!";
auto s = migraphx::interpolate_string(input, {{"w", "world"}, {"h", "Hello"}});
EXPECT(s == "Hello world!");
}
TEST_CASE(interpolate_string_next)
{
std::string input = "${hh}${ww}!";
auto s = migraphx::interpolate_string(input, {{"ww", "world"}, {"hh", "Hello"}});
EXPECT(s == "Helloworld!");
}
TEST_CASE(interpolate_string_dollar_sign)
{
std::string input = "$hello";
auto s = migraphx::interpolate_string(input, {{"hello", "bye"}});
EXPECT(s == "$hello");
}
TEST_CASE(interpolate_string_missing)
{
std::string input = "${hello}";
EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"h", "bye"}}); }));
}
TEST_CASE(interpolate_string_custom1)
{
std::string input = "****{{a}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{", "}}");
EXPECT(s == "****b****");
}
TEST_CASE(interpolate_string_custom2)
{
std::string input = "****{{{a}}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{", "}}}");
EXPECT(s == "****b****");
}
TEST_CASE(interpolate_string_custom3)
{
std::string input = "****{{{{a}}}}****";
auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{{", "}}}}");
EXPECT(s == "****b****");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -103,7 +103,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
auto normalize_compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
......@@ -111,6 +111,13 @@ auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& i
return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.compute_shape(inputs, {}))
{
return x.compute_shape(inputs, {});
}
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
......@@ -121,7 +128,7 @@ shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
return normalize_compute_shape_op(rank<1>{}, x, inputs);
return normalize_compute_shape_op(rank<2>{}, x, inputs);
}
template <class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment