Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
095e49a3
Commit
095e49a3
authored
May 20, 2022
by
turneram
Browse files
Add transposectx and transposeqkv ref tests
parent
7757cfd0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
0 deletions
+61
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+61
-0
No files found.
test/ref_ops_test.cpp
View file @
095e49a3
...
@@ -666,6 +666,67 @@ TEST_CASE(batch_norm_inference_test)
...
@@ -666,6 +666,67 @@ TEST_CASE(batch_norm_inference_test)
EXPECT(migraphx::verify_range(result_vector, gold));
EXPECT(migraphx::verify_range(result_vector, gold));
}
}
TEST_CASE(bert_transpose_ops_test)
{
{
// transposeQKV
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> bsknh{2, 384, 3, 12, 64};
const int elements = std::accumulate(bsknh.begin(), bsknh.end(), 1, std::multiplies<int>());
migraphx::shape sh{migraphx::shape::float_type, bsknh};
std::vector<float> data(elements);
std::iota(data.begin(), data.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{sh, data});
mm->add_instruction(migraphx::make_op("transposeqkv"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector(elements);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(elements);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
// BSKNH->KBNSH : perm=2,0,3,1,4
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}),
l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
result2.visit([&](auto output) { gold.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
{
// transposeCtx
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> bnsh{2, 12, 384, 64};
const int elements = std::accumulate(bnsh.begin(), bnsh.end(), 1, std::multiplies<int>());
migraphx::shape sh{migraphx::shape::float_type, bnsh};
std::vector<float> data(elements);
std::iota(data.begin(), data.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{sh, data});
mm->add_instruction(migraphx::make_op("transposectx"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector(elements);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(elements);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
// BNSH->BSNH : perm=0,2,1,3
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
result2.visit([&](auto output) { gold.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
}
TEST_CASE(broadcast_test)
TEST_CASE(broadcast_test)
{
{
migraphx::program p;
migraphx::program p;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment