Unverified Commit 1a4ff504 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Horizontal fusions of gemms and convolutions (#472)



* Add decompose pass

* Add decompose test

* Formatting

* Add remap

* Formatting

* Add compute method for dot

* Formatting

* Add finder for horizontal fusion

* Formatting

* Formatting

* Reuse predicate

* Add gemm fusions

* Formatting

* Add some fixes for convolution

* Formatting

* Fix shape tests

* Formatting

* Reuse axis equal

* Add initial split fusion

* Formatting

* Update offset

* Workaround outputs that can't accept nonstandard shapes

* Formatting

* Add check for split concat

* Formatting

* Add missing headers

* Formatting

* Add tests

* Formatting

* Add more testing

* Formatting

* Fix when there are duplicate splits in inputs

* Formatting

* Fix mismatch iterators

* Add tests for dot fusions

* Formatting

* Add test for convolution

* Formatting

* Fix tidy issues

* Add more tests

* Formatting

* Ignore build directory for codecov

* Add test for groups

* Formatting

* Add more tests for groups

* Formatting

* Add test for missing end slice

* Add newline

* Remove unused function

* Add support for when beta is not 1

* Formatting

* Add test for scalar

* Add one more scalar test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 45bb91ea
ignore:
- "test/"
- "src/driver"
- "build/"
......@@ -5,6 +5,7 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx
auto_contiguous.cpp
eliminate_common_subexpression.cpp
decompose.cpp
propagate_constant.cpp
dead_code_elimination.cpp
eliminate_allocation.cpp
......@@ -20,6 +21,7 @@ add_library(migraphx
instruction.cpp
program.cpp
quantization.cpp
remap.cpp
shape.cpp
schedule.cpp
pass_manager.cpp
......
#include <migraphx/decompose.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
// Rewrites a 3-argument dot (GEMM with a C/bias input) into a 2-argument dot
// followed by an explicit add, materializing beta as a literal multiply when
// beta != 1.  This exposes the add to later fusion/simplification passes.
struct find_dot_add
{
    auto matcher() const { return match::name("dot")(match::nargs(3)); }

    void apply(program& p, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto dot = any_cast<op::dot>(ins->get_operator());
        // A non-unit beta must be emitted as a literal of the output type; for
        // non-floating types a fractional beta is unrepresentable, so bail and
        // leave the 3-argument dot untouched.
        if(not float_equal(dot.beta, 1) and
           not contains({shape::float_type, shape::half_type, shape::double_type},
                        ins->get_shape().type()))
            return;
        // dot(A, B) with beta = 0: the C term is handled by the add below.
        auto dot_ins =
            p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]);
        auto c_ins = ins->inputs()[2];
        if(not float_equal(dot.beta, 1))
        {
            // Scale C by beta via a broadcast multiply before the add.
            auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
            auto beta_broadcast =
                p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta);
            c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast);
        }
        p.replace_instruction(ins, op::add{}, dot_ins, c_ins);
    }
};
} // namespace

void decompose::apply(program& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
#include <algorithm>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Partitions [start, last) into equivalence groups defined by `pred` and
// invokes `out(first, last)` once per group.  Each pass partitions the
// remaining tail around the current leading element so its group becomes
// contiguous, reports it, then continues past it.  std::partition may reorder
// elements, so group membership — not original element order — is preserved.
template <class Iterator, class Output, class Predicate>
void group_by(Iterator start, Iterator last, Output out, Predicate pred)
{
    for(auto group_begin = start; group_begin != last;)
    {
        auto group_end =
            std::partition(group_begin, last, [&](auto x) { return pred(x, *group_begin); });
        out(group_begin, group_end);
        group_begin = group_end;
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
 * Decompose complex operators into simpler ones — e.g. a 3-argument dot
 * (GEMM with a C input) becomes a 2-argument dot plus multiply/add — so
 * later passes can fuse and simplify the pieces independently.
 */
struct decompose
{
    std::string name() const { return "decompose"; }
    void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/tensor_view.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reference GEMM over the last two dimensions:
//     cmat = alpha * (amat * bmat) + beta * cmat
// Any leading (batch) dimensions are iterated implicitly by shape_for_each
// over cmat's shape; a_idx/b_idx reuse each output index's batch coordinates.
// The inner product is accumulated in double to limit rounding error before
// the final store converts back to T.
template <class T, class F>
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
    std::size_t n_dims = cmat.get_shape().lens().size();
    std::size_t dim_0  = n_dims - 2;
    std::size_t dim_1  = n_dims - 1;
    // k is the shared inner dimension of the multiply
    auto k = amat.get_shape().lens()[dim_1];

    assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);

    shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
        auto a_idx = c_idx;
        auto b_idx = c_idx;
        double s   = 0.0;
        dfor(k)([&](auto kk) {
            a_idx[dim_1] = b_idx[dim_0] = kk;
            s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
        });
        cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
    });
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -9,6 +9,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <cmath>
#include <utility>
......@@ -67,6 +68,18 @@ struct dot
return {t, out_lens};
}
// Reference (CPU) evaluation of dot: C = alpha * A * B + beta * C.
// When a third argument is supplied it is the C matrix and is reused directly
// as the output buffer (gemm reads and overwrites it in place); otherwise a
// fresh buffer of output_shape is allocated.
// NOTE(review): assumes copying an `argument` shares the underlying storage,
// so writing through `result` also writes args[2]'s buffer — confirm.
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
    argument result;
    if(args.size() == 3)
        result = args[2];
    else
        result = argument{output_shape};
    visit_all(result, args[0], args[1])(
        [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
    return result;
}
};
} // namespace op
......
......@@ -74,6 +74,7 @@ struct program
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
......@@ -125,6 +126,8 @@ struct program
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
program& sort();
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
 * Remap operators: recombines decomposed operator sequences back into fused
 * forms before lowering — e.g. an add fed by a 2-argument dot becomes a
 * 3-argument dot (GEMM with C input).
 */
struct remap
{
    std::string name() const { return "remap"; }
    void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -114,6 +114,8 @@ struct shape
/// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const;
shape normalize_standard() const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
......@@ -11,6 +11,7 @@
#include <iostream>
#include <sstream>
#include <algorithm>
#include <set>
#include <utility>
namespace migraphx {
......@@ -260,6 +261,14 @@ instruction_ref program::move_instruction(instruction_ref src, instruction_ref d
return src;
}
// Moves `src` in front of `dst`, then moves each of src's direct inputs in
// front of src so they still precede it in the instruction list.  Note this
// relocation is only one level deep: inputs-of-inputs are not moved, so
// callers must ensure those already appear early enough in the list.
instruction_ref program::move_instructions(instruction_ref src, instruction_ref dst)
{
    this->move_instruction(src, dst);
    for(auto ins : src->inputs())
        this->move_instruction(ins, src);
    return src;
}
instruction_ref program::add_literal(literal l)
{
impl->instructions.emplace_front(std::move(l));
......@@ -796,6 +805,17 @@ void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
});
}
// Topologically re-orders the instruction list: starting from the final
// instruction, each visited instruction is moved to the front of the list and
// then its inputs are visited recursively (via the `fix` Y-combinator), so
// every instruction ends up after all of its inputs.
program& program::sort()
{
    fix([&](auto self, auto ins) {
        this->move_instruction(ins, this->begin());
        for(auto child : ins->inputs())
            self(child);
    })(std::prev(this->end()));
    // validate() returns end() when the program is well-formed
    assert(this->validate() == this->end());
    return *this;
}

// Programs compare equal when their printed representations match.
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
......
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
auto matcher() const
{
return match::name("add")(match::any_of(
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
match::args(match::used_once().bind("a"),
match::name("dot")(match::nargs(2)).bind("dot"))));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = r.instructions["a"];
auto dot = any_cast<op::dot>(dot_ins->get_operator());
dot.beta = 1;
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
}
};
} // namespace
void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -195,6 +195,14 @@ bool shape::scalar() const
bool shape::standard() const { return impl->m_standard; }
// Returns a canonical copy of this shape: a standard (contiguous, row-major)
// shape is rebuilt from its type and lengths alone, so its strides take the
// canonical form; any non-standard shape is returned unchanged.
shape shape::normalize_standard() const
{
    if(not this->standard())
        return *this;
    return {this->type(), this->lens()};
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const
......
......@@ -4,7 +4,9 @@
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/neg.hpp>
......@@ -12,6 +14,7 @@
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/algorithm.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -246,6 +249,212 @@ struct find_concat_binary
}
};
// If the outputs of `ins` slice it into two or more pieces that exactly tile
// it along a common set of axes, returns those slice instructions ordered by
// start offset; otherwise returns an empty vector.  The validation requires:
//   * at least two slice outputs,
//   * every slice uses the same axes,
//   * after sorting by starts, the first slice starts at 0 on every axis,
//   * each slice's ends equal the next slice's starts (no gaps or overlaps),
//   * the last slice's ends reach the full dimension on every sliced axis.
std::vector<instruction_ref> get_splits(instruction_ref ins)
{
    std::vector<instruction_ref> result;
    std::copy_if(ins->outputs().begin(),
                 ins->outputs().end(),
                 std::back_inserter(result),
                 [&](auto i) { return i->name() == "slice"; });
    if(result.size() < 2)
        return {};
    auto get_slice = [](auto& i) -> auto& { return any_cast<op::slice>(i->get_operator()); };
    auto&& axes    = get_slice(result.front()).axes;
    if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; }))
        return {};
    auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; };
    auto get_end   = [&](auto& i) -> auto& { return get_slice(i).ends; };
    // Order the slices by start offset (lexicographic across the sliced axes)
    std::sort(
        result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); });
    if(std::any_of(get_start(result.front()).begin(), get_start(result.front()).end(), [&](auto i) {
           return i != 0;
       }))
        return {};
    // Adjacent slices must abut exactly
    auto it = std::adjacent_find(
        result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); });
    if(it != result.end())
        return {};
    // The final slice must cover the end of every sliced dimension
    for(std::size_t i = 0; i < axes.size(); i++)
    {
        auto axis = axes[i];
        if(ins->get_shape().lens()[axis] != get_slice(result.back()).ends[i])
            return {};
    }
    return result;
}
// Horizontal fusion across a split: when an instruction is sliced into pieces
// and each piece flows through the same elementwise op (add/mul/relu), apply
// that op once to the unsplit tensor (concatenating any constant second
// operands along the split axis) and re-slice the fused result.
struct find_splits
{
    auto matcher() const
    {
        return match::any(match::any_of[match::outputs()](match::name("slice")(
            match::any_of[match::outputs()](match::name("add", "mul", "relu")))));
    }

    // Groups the consumers of the splits: for each non-slice consumer of the
    // first split, collects the matching consumer (same operator) of every
    // other split.  Only complete groups — one consumer per split — are kept;
    // a duplicate consumer aborts the whole grouping.
    static std::vector<std::vector<instruction_ref>>
    get_split_groups(const std::vector<instruction_ref>& splits)
    {
        std::vector<std::vector<instruction_ref>> groups;
        for(auto out : splits.front()->outputs())
        {
            if(out->name() == "slice")
                continue;
            std::vector<instruction_ref> group;
            for(auto split : splits)
            {
                auto it =
                    std::find_if(split->outputs().begin(), split->outputs().end(), [&](auto i) {
                        return i->get_operator() == out->get_operator();
                    });
                if(it == split->outputs().end())
                    break;
                assert((*it)->name() != "slice");
                // If there is a duplicate, bail
                if(contains(group, *it))
                    return {};
                group.push_back(*it);
            }
            if(group.size() != splits.size())
                continue;
            groups.push_back(group);
        }
        return groups;
    }

    void apply(program& p, const match::matcher_result& r) const
    {
        auto ins    = r.result;
        auto splits = get_splits(ins);
        if(splits.empty())
            return;
        for(const auto& group : get_split_groups(splits))
        {
            auto start = group.front();
            auto op    = start->get_operator();
            if(op.name() == "slice")
                continue;
            // Make sure there are no duplicates
            assert(std::none_of(
                std::next(group.begin()), group.end(), [&](auto i) { return i == start; }));
            // Index of the split operand within the consumer's inputs
            auto split_idx    = 0;
            instruction_ref c = p.end();
            if(start->inputs().size() == 1)
            {
                // Unary op (relu): just apply it to the unsplit tensor
                c = p.insert_instruction(std::next(ins), op, ins);
            }
            else if(start->inputs().size() == 2)
            {
                assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
                           return i->name() == "slice";
                       }) && "one argument must be a split");
                auto data_idx = 1;
                if(start->inputs().back()->name() == "slice")
                {
                    split_idx = 1;
                    data_idx  = 0;
                }
                // Collect each consumer's non-split operand
                std::vector<instruction_ref> data_args;
                std::transform(group.begin(),
                               group.end(),
                               std::back_inserter(data_args),
                               [&](auto i) { return i->inputs()[data_idx]; });
                // Data arguments must be constants so they can be concatenated
                if(std::any_of(data_args.begin(), data_args.end(), [](auto i) {
                       return not i->can_eval();
                   }))
                    return;
                // Relocate the constants before `ins` so the concat is valid
                for(auto data : data_args)
                    p.move_instructions(data, ins);
                auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
                assert(not slice_op.axes.empty());
                // Multi-axis splits are not handled
                if(slice_op.axes.size() > 1)
                    return;
                auto concat_axis = slice_op.axes.front();
                // TODO: Check if axes match
                auto concat = p.insert_instruction(ins, op::concat{concat_axis}, data_args);
                std::vector<instruction_ref> args;
                args.resize(2);
                args[split_idx] = ins;
                args[data_idx]  = concat;
                c               = p.insert_instruction(std::next(ins), op, args);
            }
            if(c != p.end())
            {
                // Replace each per-slice consumer with a slice of the fused op
                for(auto i : group)
                {
                    auto split = i->inputs()[split_idx];
                    assert(split->name() == "slice");
                    // Insert contiguous before reshape-like consumers, which
                    // cannot accept the nonstandard shape a slice produces
                    for(auto output : i->outputs())
                    {
                        if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
                            continue;
                        auto x = p.insert_instruction(output, op::contiguous{}, output->inputs());
                        p.replace_instruction(output, output->get_operator(), x);
                    }
                    p.replace_instruction(i, split->get_operator(), c);
                }
            }
        }
    }
};
// Removes a pointless split/concat round-trip: when every slice of a tensor
// feeds only one concat, and the concat runs along the same axis as the
// split, the slices can be replaced by the original unsplit tensor.
struct find_split_concat
{
    auto matcher() const
    {
        return match::any(match::any_of[match::outputs()](
            match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
    }

    void apply(program& p, const match::matcher_result& r) const
    {
        auto ins    = r.result;
        auto splits = get_splits(ins);
        if(splits.empty())
            return;
        // Each slice must feed exactly one consumer
        if(std::any_of(
               splits.begin(), splits.end(), [](auto i) { return i->outputs().size() != 1; }))
            return;
        // All slices must feed the same concat operator
        auto concat = splits.front()->outputs().front();
        if(std::any_of(splits.begin(), splits.end(), [&](auto i) {
               return i->outputs().front() != concat;
           }))
            return;
        // The concat axis must match the (single) split axis
        auto concat_op = any_cast<op::concat>(concat->get_operator());
        auto split_op  = any_cast<op::slice>(splits.front()->get_operator());
        if(split_op.axes.size() != 1)
            return;
        if(split_op.axes.front() != concat_op.axis)
            return;
        // Replace the run of slice arguments with the unsplit tensor.
        // NOTE(review): this assumes the slices occupy splits.size()
        // consecutive argument slots of the concat, in split order, starting
        // at the first split — confirm that assumption before relying on it.
        auto args = concat->inputs();
        auto it =
            std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); });
        if(std::distance(it, args.end()) < splits.size())
            return;
        *it = splits.front()->inputs().front();
        args.erase(std::next(it), it + splits.size());

        if(args.size() == 1)
            p.replace_instruction(concat, args.front());
        else
            p.replace_instruction(concat, concat->get_operator(), args);
    }
};
bool axis_equal(const std::vector<std::size_t>& x,
const std::vector<std::size_t>& y,
std::size_t axis)
......@@ -352,6 +561,83 @@ struct find_add_convs
}
};
// Predicate matcher: true when `ins` feeds at least two dot ops or at least
// two convolution ops, each taking `ins` as its first input and a
// compile-time-evaluable second input — i.e. a horizontal fusion candidate.
MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
{
    auto is_candidate = [&](auto opname) {
        return [=](auto out) {
            return out->name() == opname and out->inputs().front() == ins and
                   out->inputs().at(1)->can_eval();
        };
    };
    const auto& outs = ins->outputs();
    auto ndot        = std::count_if(outs.begin(), outs.end(), is_candidate("dot"));
    auto nconv       = std::count_if(outs.begin(), outs.end(), is_candidate("convolution"));
    return ndot >= 2 or nconv >= 2;
}
// Horizontal fusion of sibling gemms/convolutions: groups the consumers of a
// common input by compatible operator + weight shape, concatenates their
// constant weights (output-channel axis for convolution, last axis for dot),
// runs one fused op, and replaces each original consumer with a slice of the
// fused result.
struct find_conv_dot_horiz_fusion
{
    auto matcher() const { return horiz_conv_dot(); }

    void apply(program& p, const match::matcher_result& r) const
    {
        auto ins = r.result;

        // Two consumers belong to the same fusion group when they run the
        // same operator and (for dot/convolution) their weight shapes agree
        // on every dimension except the concatenation axis.
        auto pred = [](auto i, auto j) {
            if(i->get_operator() != j->get_operator())
                return false;
            if(not contains({"dot", "convolution"}, i->name()))
                return true;
            auto x = i->inputs()[1]->get_shape().lens();
            auto y = j->inputs()[1]->get_shape().lens();
            if(x.size() != y.size())
                return false;
            // Check that the non-concat axes match
            int axis = 1;
            if(i->name() == "dot")
            {
                axis = x.size() - 1;
            }
            return axis_equal(x, y, axis);
        };

        // Fuse one group of compatible consumers
        auto each = [&](auto start, auto last) {
            if(std::distance(start, last) < 2)
                return;
            auto&& name = (*start)->name();
            if(not contains({"dot", "convolution"}, name))
                return;
            auto input = (*start)->inputs().front();
            // Gather the constant weight arguments of the group
            std::vector<instruction_ref> args;
            std::transform(
                start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
            // Convolution: weights concat on axis 0, output slices on axis 1;
            // dot: both use the last dimension.
            int axis        = 1;
            int concat_axis = 0;
            if(name == "dot")
            {
                axis        = int(args.front()->get_shape().lens().size() - 1);
                concat_axis = axis;
            }
            // Relocate the weights before `input` so the concat is valid
            for(auto arg : args)
                p.move_instructions(arg, input);
            // TODO: Check if axes match
            auto concat = p.insert_instruction(input, op::concat{concat_axis}, args);
            auto fused =
                p.insert_instruction(std::next(input), (*start)->get_operator(), input, concat);
            // Replace each original consumer with its slice of the fused op
            int64_t offset = 0;
            for(auto arg : range(start, last))
            {
                int64_t len = arg->get_shape().lens()[axis];
                p.replace_instruction(arg, op::slice{{axis}, {offset}, {offset + len}}, fused);
                offset += len;
            }
        };

        auto outputs = ins->outputs();
        group_by(outputs.begin(), outputs.end(), each, pred);
    }
};
struct find_div_const
{
auto matcher() const
......@@ -412,20 +698,23 @@ struct find_rsqrt
// Runs all algebraic-simplification matchers over the program.  The rewrites
// can expose new opportunities for one another (e.g. a horizontal fusion
// creating a split that find_splits then fuses), so the whole set is iterated
// a fixed number of times, with dead-code elimination after each round.
void simplify_algebra::apply(program& p) const
{
    // Run simplifications multiple times
    for(int i = 0; i < 8; i++)
    {
        match::find_matches(p,
                            find_inner_broadcast{},
                            find_double_add_lit_broadcast{},
                            find_add_lit_broadcast{},
                            find_add_convs{},
                            find_conv_dot_horiz_fusion{},
                            find_mul_conv{},
                            find_mul_add{},
                            find_div_const{},
                            find_sub_const{},
                            find_rsqrt{},
                            find_concat_unary{},
                            find_concat_binary{},
                            find_split_concat{},
                            find_splits{});
        dead_code_elimination{}.apply(p);
    }
}
......
......@@ -40,6 +40,7 @@ struct fusion
fusion(const shape& input)
// : fp(make_fusion_plan(input))
{
assert(input.standard());
auto t = make_tensor(input);
fp = make_fusion_plan(t);
keep_alive(std::move(t));
......
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/gpu/context.hpp>
......
......@@ -38,8 +38,9 @@ Result make_obj(F f, Ts... xs)
return r;
}
inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false)
inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = false)
{
auto s = os.normalize_standard();
auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor);
// Convert to ints
std::vector<int> lens(s.lens().begin(), s.lens().end());
......
......@@ -25,6 +25,8 @@
#include <migraphx/gpu/preallocate_param.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/decompose.hpp>
#include <migraphx/remap.hpp>
#include <migraphx/schedule.hpp>
namespace migraphx {
......@@ -39,6 +41,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
// clang-format off
return
{
decompose{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
......@@ -59,6 +62,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
propagate_constant{},
dead_code_elimination{},
remap{},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{},
dead_code_elimination{},
......
#include <migraphx/decompose.hpp>
#include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp>
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <test.hpp>
// Runs only the decompose pass over the program under test.
void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::decompose{}}); }
// A 3-argument dot with unit beta decomposes into dot(alpha=1, beta=0)
// followed by a plain add of the C argument.
TEST_CASE(dot_add)
{
    migraphx::program p1;
    {
        auto x   = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto y   = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto z   = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto dot = p1.add_instruction(migraphx::op::dot{}, x, y, z);
        p1.add_instruction(migraphx::op::identity{}, dot);
    }
    run_pass(p1);

    // Expected program after decomposition
    migraphx::program p2;
    {
        auto x   = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto y   = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto z   = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
        auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
        auto add = p2.add_instruction(migraphx::op::add{}, dot, z);
        p2.add_instruction(migraphx::op::identity{}, add);
    }
    EXPECT(p1 == p2);
}
TEST_CASE(dot_add_beta_float)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
TEST_CASE(dot_add_beta_half)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
TEST_CASE(dot_add_beta_double)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
p1.add_instruction(migraphx::op::identity{}, dot);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y);
auto beta =
p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta);
auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast);
auto add = p2.add_instruction(migraphx::op::add{}, dot, mul);
p2.add_instruction(migraphx::op::identity{}, add);
}
EXPECT(p1 == p2);
}
// Integer tensors are not decomposed when beta is fractional: the beta
// literal cannot be represented in the element type, so the pass must leave
// the program untouched.
TEST_CASE(dot_add_beta_int)
{
    migraphx::program p1;
    {
        auto x   = p1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto y   = p1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto z   = p1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z);
        p1.add_instruction(migraphx::op::identity{}, dot);
    }
    // Copy taken before the pass; the pass should be a no-op here
    migraphx::program p2 = p1;
    run_pass(p1);
    EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -47,6 +47,15 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted());
}
// A size-1 leading dimension with an arbitrary stride must still count as
// standard and packed: the stride of a length-1 axis never affects layout.
TEST_CASE(test_shape_non_packed_single_dim)
{
    migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}};
    EXPECT(s.standard());
    EXPECT(s.packed());
    EXPECT(not s.transposed());
    EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed1)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
......@@ -172,6 +181,53 @@ TEST_CASE(test_shape_default_copy)
EXPECT(!(s1 != s2));
}
// A shape whose strides are already canonical normalizes to itself.
TEST_CASE(test_shape_normalize_standard1)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
    EXPECT(s.standard());
    auto n = s.normalize_standard();
    EXPECT(n == s);
}

// Standard but with a non-canonical stride on a size-1 axis: normalization
// rebuilds the strides, yielding a different (yet still standard) shape.
TEST_CASE(test_shape_normalize_standard2)
{
    migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}};
    EXPECT(s.standard());
    auto n = s.normalize_standard();
    EXPECT(n.standard());
    EXPECT(n != s);
    EXPECT(n.lens() == s.lens());
    EXPECT(n.type() == s.type());
}

// Non-standard (transposed) shapes pass through normalization unchanged.
TEST_CASE(test_shape_normalize_standard3)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
    EXPECT(not s.standard());
    auto n = s.normalize_standard();
    EXPECT(n == s);
}

// A default (scalar) shape is standard, so it normalizes to a rebuilt shape
// that is no longer reported as scalar.
TEST_CASE(test_shape_normalize_scalar1)
{
    migraphx::shape s{migraphx::shape::float_type};
    EXPECT(s.standard());
    EXPECT(s.scalar());
    auto n = s.normalize_standard();
    EXPECT(n != s);
    EXPECT(n.standard());
    EXPECT(not n.scalar());
}

// A broadcasted scalar (all-zero strides) is non-standard, so normalization
// leaves it unchanged.
TEST_CASE(test_shape_normalize_scalar2)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 0}};
    EXPECT(not s.standard());
    EXPECT(s.scalar());
    auto n = s.normalize_standard();
    EXPECT(n == s);
}
TEST_CASE(test_shape4)
{
migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment