Commit 08161428 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add more test cases for better code coverage.

parent 5589227b
......@@ -30,7 +30,7 @@ struct gather
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
......
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp>
......@@ -39,4 +41,17 @@ TEST_CASE(non_standard_op)
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
TEST_CASE(transpose_gemm)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto ic = p.add_instruction(migraphx::op::identity{}, c);
p.add_instruction(migraphx::op::dot{}, ic, l);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment