Commit 095e49a3 authored by turneram's avatar turneram
Browse files

Add transposectx and transposeqkv ref tests

parent 7757cfd0
......@@ -666,6 +666,67 @@ TEST_CASE(batch_norm_inference_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(bert_transpose_ops_test)
{
{
// transposeQKV
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> bsknh{2, 384, 3, 12, 64};
const int elements = std::accumulate(bsknh.begin(), bsknh.end(), 1, std::multiplies<int>());
migraphx::shape sh{migraphx::shape::float_type, bsknh};
std::vector<float> data(elements);
std::iota(data.begin(), data.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{sh, data});
mm->add_instruction(migraphx::make_op("transposeqkv"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector(elements);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(elements);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
// BSKNH->KBNSH : perm=2,0,3,1,4
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}),
l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
result2.visit([&](auto output) { gold.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
{
// transposeCtx
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> bnsh{2, 12, 384, 64};
const int elements = std::accumulate(bnsh.begin(), bnsh.end(), 1, std::multiplies<int>());
migraphx::shape sh{migraphx::shape::float_type, bnsh};
std::vector<float> data(elements);
std::iota(data.begin(), data.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{sh, data});
mm->add_instruction(migraphx::make_op("transposectx"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector(elements);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(elements);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
// BNSH->BSNH : perm=0,2,1,3
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
result2.visit([&](auto output) { gold.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
}
TEST_CASE(broadcast_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment