Commit eb0d8fee authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into driver

parents 65ef35cd 0d796941
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;
// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;
// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};
......@@ -113,13 +113,27 @@ endif()
# Onnx test
set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_executable(test_onnx ${TEST_ONNX_DIR}/onnx_test.cpp)
rocm_clang_tidy_check(test_onnx)
target_link_libraries(test_onnx migraphx_onnx)
target_include_directories(test_onnx PUBLIC include)
add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_dependencies(tests test_onnx)
add_dependencies(check test_onnx)
file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
foreach(ONNX_TEST ${ONNX_TESTS})
get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
endforeach()
# tf test
add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf)
target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_dependencies(tests test_tf)
add_dependencies(check test_tf)
if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py)
......
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -59,7 +60,7 @@ TEST_CASE(after_literal_broadcast)
auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2);
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
......@@ -90,7 +91,7 @@ TEST_CASE(after_param_broadcast)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2);
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted());
......
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
template <class T>
void matmul_test()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test<float>)
TEST_CASE_REGISTER(matmul_test<double>)
template <class T>
void matmul_test_ex()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test_ex<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>)
TEST_CASE(matmul_mutli_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_beta0)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
std::vector<float> m3 = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 2, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_beta_0)
{
migraphx::program p;
std::vector<float> m1 = {
-0.76234141, 0.01368910, -0.86343423, -0.99465282, 0.76133268, 0.96507140};
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
std::vector<float> m2 = {-0.15933632,
-0.69594712,
-0.06198966,
-1.23905184,
-0.83672704,
-1.06971832,
-0.12272917,
1.07094116,
-0.08346820,
1.16820693,
-0.95700874,
0.24059691};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
std::vector<float> m3 = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(matmul_mutli_dim_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
-1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882,
0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693,
0.03852506, -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
0.08698452, 0.94974586, 0.40087323, -0.05448534, 0.03220677, -1.22494296,
0.97938472, -1.43714454, -0.80430904, -0.08098728, 0.31520301, 0.49642169,
-1.63471091, 0.34390096, 2.81292176, -0.22666528, 1.54559556, -1.51075762};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.33170529, 2.26325120, -0.50639461, 0.64802947, 0.44748888, 0.33768068,
-0.53621075, 0.34341460, 0.58742520, -1.13995790, -0.99322535, 0.35447353,
0.01977110, -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
-0.39112198, 1.42303608, -0.02853060, 1.52610164, 0.53540909, 0.75618998,
-0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803,
-1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape().lens()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_3args)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_3args)
{
{
migraphx::program p;
std::vector<float> a = {-0.86217194,
-1.04129542,
-0.64850364,
-0.97078327,
-0.40516386,
0.83136927,
0.37717502,
0.42271939,
1.10062165,
-0.92239359,
0.40403076,
-0.43935377};
std::vector<float> b = {0.76084386,
1.89201125,
1.73218067,
0.7148568,
-0.55578914,
0.05799101,
-1.24090721,
-0.51151978,
1.13255803,
0.21540723,
-1.10459009,
0.45580331};
std::vector<float> c = {-0.80473623,
0.35154171,
-2.73077756,
-0.09093885,
-1.88850472,
-0.03375556,
-0.41798276,
2.87368099,
2.11031439};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::dot{}, al, bl, cl);
std::vector<float> gold = {-1.60947,
0.703083,
-5.46156,
-0.181878,
-3.77701,
-0.0675112,
-0.835966,
5.74736,
4.22063};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_vv_inner_product)
{
{
migraphx::program p;
std::vector<float> a = {0.7481789,
0.02906279,
1.01193836,
1.60222907,
1.89135978,
0.30054158,
-0.4892588,
-0.27027533};
std::vector<float> b = {-0.25829116,
0.27908929,
-1.27888957,
0.21152361,
0.08593658,
0.52163899,
1.38343824,
-0.2342857};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, ual, ubl);
std::vector<float> gold = {-1.43461};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {0.7481789,
0.02906279,
1.01193836,
1.60222907,
1.89135978,
0.30054158,
-0.4892588,
-0.27027533};
std::vector<float> b = {-0.25829116,
0.27908929,
-1.27888957,
0.21152361,
0.08593658,
0.52163899,
1.38343824,
-0.2342857};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.32f;
p.add_instruction(migraphx::op::dot{alpha}, ual, ubl);
std::vector<float> gold = {-0.4590752};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_vm)
{
{
migraphx::program p;
std::vector<float> a = {1.49530002,
-0.07181969,
0.44593846,
-0.8645019,
0.52992304,
-0.4910338,
-2.12179422,
-0.45962977};
std::vector<float> b = {-0.06210242, 0.0187149, 1.47482984, -1.19590602, -0.45601701,
0.36934488, -0.83913193, 0.75350964, 0.80707019, 0.35923582,
-2.18480722, -0.85608682, 0.75849199, 0.49103473, -0.91329477,
-0.36364322, -0.69688937, 0.07165814, -0.15505523, 0.52221663,
-0.98631192, -0.37353654, -1.89818706, -0.87209739, -0.33942003,
0.11390353, 0.78181162, -0.18395337, -0.34743419, -0.08091231,
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, ual, bl);
std::vector<float> gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {1.49530002,
-0.07181969,
0.44593846,
-0.8645019,
0.52992304,
-0.4910338,
-2.12179422,
-0.45962977};
std::vector<float> b = {-0.06210242, 0.0187149, 1.47482984, -1.19590602, -0.45601701,
0.36934488, -0.83913193, 0.75350964, 0.80707019, 0.35923582,
-2.18480722, -0.85608682, 0.75849199, 0.49103473, -0.91329477,
-0.36364322, -0.69688937, 0.07165814, -0.15505523, 0.52221663,
-0.98631192, -0.37353654, -1.89818706, -0.87209739, -0.33942003,
0.11390353, 0.78181162, -0.18395337, -0.34743419, -0.08091231,
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f;
p.add_instruction(migraphx::op::dot{alpha}, ual, bl);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-1.7468318, -0.38900251, 1.00183915, 0.06016438, 0.08295905, 1.5830535};
std::vector<float> b = {
1.2459538, 0.39586199, -0.77035574, 0.22689828, 0.3289835, 1.02804361,
-0.22941113, -0.33940324, 0.80078249, 1.0319152, 0.80034948, -0.11631159,
0.36899208, -0.28506697, -1.2211584, -0.55678377, -0.3618498, 0.34857264,
-0.38700147, -0.43434611, 1.73029783, -0.71578372, 0.09777723, 0.06616614,
-1.66721186, -0.16046032, -1.64581663, 1.09373609, -0.14127692, -0.01938473,
-0.67310303, -1.56154787, -1.0665462, 0.68538535, -1.53920085, -0.35710272,
0.06887234, 0.17474616, 1.08194804, -0.19990148, -0.91149488, 0.95303646,
0.95448717, -0.49332393, -1.762213, -0.56571194, -1.69704968, -0.82798066,
0.65531872, 1.5007798, 0.99877355, 0.53386114, -0.88150609, -1.0756985,
0.50962511, -0.68019002, 0.1583068, 2.83988407, -1.10292457, 0.02126969,
0.21129951, 0.25690146, -1.6490316, 0.55261771, -1.70504303, -0.02870394,
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, bual, bl);
std::vector<float> gold = {1.22914,
-1.17896,
2.28596,
-0.345637,
-0.962362,
0.168508,
-0.947471,
-3.02458,
-3.80131,
1.38484,
-2.45019,
-1.35064};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-1.7468318, -0.38900251, 1.00183915, 0.06016438, 0.08295905, 1.5830535};
std::vector<float> b = {
1.2459538, 0.39586199, -0.77035574, 0.22689828, 0.3289835, 1.02804361,
-0.22941113, -0.33940324, 0.80078249, 1.0319152, 0.80034948, -0.11631159,
0.36899208, -0.28506697, -1.2211584, -0.55678377, -0.3618498, 0.34857264,
-0.38700147, -0.43434611, 1.73029783, -0.71578372, 0.09777723, 0.06616614,
-1.66721186, -0.16046032, -1.64581663, 1.09373609, -0.14127692, -0.01938473,
-0.67310303, -1.56154787, -1.0665462, 0.68538535, -1.53920085, -0.35710272,
0.06887234, 0.17474616, 1.08194804, -0.19990148, -0.91149488, 0.95303646,
0.95448717, -0.49332393, -1.762213, -0.56571194, -1.69704968, -0.82798066,
0.65531872, 1.5007798, 0.99877355, 0.53386114, -0.88150609, -1.0756985,
0.50962511, -0.68019002, 0.1583068, 2.83988407, -1.10292457, 0.02126969,
0.21129951, 0.25690146, -1.6490316, 0.55261771, -1.70504303, -0.02870394,
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{0.21f}, bual, bl);
std::vector<float> gold = {0.25812,
-0.247582,
0.480051,
-0.0725837,
-0.202096,
0.0353867,
-0.198969,
-0.635161,
-0.798275,
0.290817,
-0.514539,
-0.283635};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_mv)
{
{
migraphx::program p;
std::vector<float> a = {0.1612524,
0.61266466,
-0.19212896,
1.34228825,
-1.09746949,
0.4680955,
-0.431748,
-0.89791241,
-2.19078702,
-0.13767058,
-1.66105228,
-0.91834613,
0.59199744,
1.41967261,
0.76237423};
std::vector<float> b = {0.14365572, 0.23401411, -0.8970094, -0.12526676, -1.04703286};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, al, ubl);
std::vector<float> gold = {1.31982, 1.19022, -1.96062};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {0.1612524,
0.61266466,
-0.19212896,
1.34228825,
-1.09746949,
0.4680955,
-0.431748,
-0.89791241,
-2.19078702,
-0.13767058,
-1.66105228,
-0.91834613,
0.59199744,
1.41967261,
0.76237423};
std::vector<float> b = {0.14365572, 0.23401411, -0.8970094, -0.12526676, -1.04703286};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.3f;
p.add_instruction(migraphx::op::dot{alpha}, al, ubl);
std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
1.24593227, -0.84351316, 0.27882229, -0.42518484, -1.11391528, 0.59141834,
1.34198714, 2.25884063, -1.32093452, 0.44766336, -0.09306479, 0.47526699,
0.25858488, 1.30820392, 1.17186787, 0.31530864, -1.19159424, -0.24100903,
-1.03857886, 1.54453427, 0.05041654, 1.67108177, 0.965805, 0.52958924,
-1.61243992, 0.02941846, 0.77523836, 1.97963853, -2.51093596, 0.21882645,
-2.60193574, 1.1899952, 1.70883519, 0.94586745, 2.65002512, -1.42427102,
1.0143951, -1.34115312, 1.63833732, -1.46477355, 0.44014877, 0.58032696,
-1.63874372, -0.82834423, 1.81131778, -0.52393379, 1.16721943, 0.39488835,
0.23947128, -0.15733194, 0.19451158, 1.21315445, 0.44594897, 0.40809135,
-0.64252994, 0.7541716, -0.97203195, 0.69208485, 0.34350988, 0.9836842};
std::vector<float> b = {0.05013914, 1.39932885, 2.56616476, 1.02225623, -0.03977829};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
auto bubl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 1}}, ubl);
p.add_instruction(migraphx::op::dot{}, al, bubl);
std::vector<float> gold = {-0.792717,
6.33595,
2.61466,
-3.39322,
5.42485,
3.59084,
6.78139,
-0.360492,
-4.28998,
2.87146,
3.29447,
0.765651};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_mm1)
{
{
migraphx::program p;
std::vector<float> a = {
-0.49450006, -1.07431991, -0.02796692, -0.99631927, 0.20040449, -1.39709437,
-0.15695328, 0.08208373, -0.09746386, 0.77923021, -0.1849151, 0.14419043,
-0.25798175, -0.2504807, -1.11134383, -0.71030613, -0.20234025, 0.90229168,
0.62643053, -0.83512638, 1.66051254, 0.05941673, 0.73081559, 0.27111867,
0.55060745, 0.34999583, 1.02236619, 0.60178395, 1.49646162, 1.93255155,
-3.65357913, -1.38059906, -0.46302398, 0.19847152, 0.39785875, 1.47004861,
-1.24482133, -0.01954702, 0.36073898, 1.56055978, -0.10344603, -0.34283135,
-0.56482649, 1.80861249, -0.92268202, 0.94371182, -0.02373232, -0.75441145,
0.43325034, 0.4057425, -0.48844822, -0.36390512, 0.74110406, 1.25158366,
0.52196654, 1.43461691, -0.57530864, -0.66716206, -1.76516289, 0.96582849};
std::vector<float> b = {0.49899375,
-2.20168661,
1.08895066,
-0.01135643,
0.90570669,
-1.43550963,
-1.73033377,
0.21338776,
0.96962508,
0.38913968,
-0.32822861,
0.88222863,
0.93330718,
-1.24265228,
-1.62587164};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
0.182745, -4.21937, 1.77551, 1.50775, -2.60888, -2.32484,
-0.557691, 6.13527, -2.91743, 2.37836, -6.42584, 1.14979,
0.77227, 0.349659, 2.92759, 2.32384, -2.90664, 0.0527679,
-0.547761, -0.155467, 0.964619, 2.09133, -4.44281, -1.3864};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {-0.0309568,
-1.57294749,
-0.00768606,
1.5786921,
0.50519718,
0.10530702,
-0.05302112,
-0.06503757,
0.4079716,
0.0799132,
-0.82624962,
0.49341502};
std::vector<float> b = {
0.3664867, 0.24649534, 1.14728076, 1.09911548, -1.23711247, -0.49436419,
-0.67557879, -0.84180575, -1.09754376, 0.07807351, 0.74349043, -0.92084701,
0.50267885, 0.78709401, 0.80598159, -0.51269589, -0.40337193, 0.29457878,
1.25447301, -1.66251457, -1.54652239, -0.35067765, -0.5214464, -0.7866878,
1.11128573, 0.26927291, -0.0929818, 0.07523954, 0.3256776, -1.08617826,
0.89294253, -0.91007619, -2.42825765, -1.76805581, 1.08136334, -0.14521253,
-1.32061148, 0.60663124, -1.19835255, -0.98803563, -1.06927896, -0.51967419,
-0.98974639, 1.01287011, 1.34910394, 0.1203349, 0.67387452, -0.32447465,
1.15187449, -0.82253807, 0.22302433, 0.46434695, 0.319647, 1.56459445,
0.15664012, 0.03998102, 0.62981041, 0.11831296, 0.47824434, -0.93941882,
-0.34674036, 1.17071104, 0.59203806, 2.75817738, -0.69300013, 1.30971899,
-0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, bal, bl);
std::vector<float> gold = {
-1.61175, 3.11849, -0.703205, 0.331635, -0.00946922, 0.645626, 0.834069, 1.06409,
0.881037, 0.227628, -0.200308, -1.71836, 0.156255, 0.477222, 0.571363, -1.04543,
1.40524, 1.24201, -2.95083, 1.19352, 1.5008, 0.636987, 0.148256, -0.0231631,
-1.15079, 1.42139, 1.80996, 1.79259, 2.7192, 0.331902, -0.726565, 0.0963351,
-0.710558, 0.259424, -0.342345, -1.80522, -0.580476, 0.277368, -3.95582, 0.614823,
-0.415107, 0.305138, 0.435993, -0.107089, -0.767885, -4.00837, 1.09921, -2.02129,
0.109717, 0.618422, 0.438342, 0.29602, 2.00928, 0.420871};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_mm2)
{
{
migraphx::program p;
std::vector<float> a = {
-0.49450006, -1.07431991, -0.02796692, -0.99631927, 0.20040449, -1.39709437,
-0.15695328, 0.08208373, -0.09746386, 0.77923021, -0.1849151, 0.14419043,
-0.25798175, -0.2504807, -1.11134383, -0.71030613, -0.20234025, 0.90229168,
0.62643053, -0.83512638, 1.66051254, 0.05941673, 0.73081559, 0.27111867,
0.55060745, 0.34999583, 1.02236619, 0.60178395, 1.49646162, 1.93255155,
-3.65357913, -1.38059906, -0.46302398, 0.19847152, 0.39785875, 1.47004861,
-1.24482133, -0.01954702, 0.36073898, 1.56055978, -0.10344603, -0.34283135,
-0.56482649, 1.80861249, -0.92268202, 0.94371182, -0.02373232, -0.75441145,
0.43325034, 0.4057425, -0.48844822, -0.36390512, 0.74110406, 1.25158366,
0.52196654, 1.43461691, -0.57530864, -0.66716206, -1.76516289, 0.96582849};
std::vector<float> b = {-1.12211357, 1.74720423, 0.60382572, -0.61090125, -0.3315936,
0.30924675, -0.28906435, 0.64039247, -1.2822253, 0.55899286,
2.14013013, 1.00944809, 0.21660017, -0.75465098, 0.12097934,
-1.64006315, 0.43582108, -0.64348541, 0.43101069, 1.30191386,
1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427,
1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
-1.84001907, 3.51427391, 0.42425673, 0.0638482, 2.40210271, 1.50027643,
4.81988916, -3.63687142, -0.19101717, -4.92522092, -1.76377022, -3.58095615,
1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025,
0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821};
p.add_instruction(migraphx::op::dot{}, al, bbl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {-0.19276159, -1.2568421, -0.321242, 1.21471077, -0.4927751,
0.69446894, -0.1786371, -1.00763473, -0.10279314, 3.02931355,
1.08359235, -0.35190132, -0.00639111, 0.78989113, 1.23538029,
0.4590747, 0.17304142, 0.42512412, 0.21076913, -0.01724556,
-0.17763898, 0.12852236, -0.00459301, 1.34498824, 0.02907823,
0.1784464, -0.20790355, -0.52336699, 0.45804085, 1.06025801};
std::vector<float> b = {-1.12211357, 1.74720423, 0.60382572, -0.61090125, -0.3315936,
0.30924675, -0.28906435, 0.64039247, -1.2822253, 0.55899286,
2.14013013, 1.00944809, 0.21660017, -0.75465098, 0.12097934,
-1.64006315, 0.43582108, -0.64348541, 0.43101069, 1.30191386,
1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427,
1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3, 5}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, bal, bbl);
std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
1.91883003e+00, -2.89962196e-01, 2.76404142e+00, 1.50048102e+00, -6.29650347e-01,
1.48105185e+00, -3.71716505e-03, 8.80281500e-01, 2.50057585e+00, 1.29958508e+00,
5.63751779e-01, 2.25703781e-01, 1.30516919e+00, 8.32118386e-01, 2.44050864e-01,
-2.49748221e+00, -5.60803176e+00, -2.98919069e+00, -1.11429417e+00, -3.29675989e+00,
1.02442564e-01, -1.87659303e+00, -4.67302454e-01, 9.16189968e-01, -1.33537175e-01,
8.27398578e-01, 1.94406914e+00, -2.39250915e-01, -1.77062701e+00, -6.46239534e-01,
-7.95202750e-01};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-0.55248691, 0.70275958, 0.56967633, 0.88206033, -0.85088547, 0.05689149,
-0.20084703, 0.18024434, 1.0730491, 0.15913531, 0.93621628, 0.35072771,
1.28616952, 1.55384379, 0.30376261, -1.12356544, -0.64271552, -2.50703079,
-0.23994372, 0.8166084, 0.06542249, -0.17472336, -0.37665211, 0.16342699,
0.07645941, 0.65024333, -1.19883423, -0.40536776, -0.31132765, 0.78113691,
-0.16887638, 2.30797418, -0.36241233, 0.33552153, -1.05343996, -0.16909699,
-1.22608815, 1.64165613, 0.96260828, -0.16733976, 0.84211199, 1.31243813,
0.89258549, -0.48250384, -1.06005206, 1.37021342, -0.35658565, 0.26879188};
std::vector<float> b = {
0.17111129, -0.82134741, -1.58001178, -1.46759447, 0.31522514, -0.11567352,
-0.038978, -0.3601414, -0.84379876, 0.24848939, -0.37080544, 0.00838631,
1.51316241, 0.42385344, 2.06043846, 1.82348849, 1.07180434, 0.6567393,
1.41164561, 0.73091185, -0.33541302, -0.98082287, -0.06605479, 0.82219717,
-1.41619634, 0.51326658, 0.26916313, 0.79819769, 0.85583702, 0.07876046,
-0.42375545, -0.7758751, 1.14334296, -0.14211708, -1.54520411, -0.55244869,
-0.48478899, 0.10782164, -0.20879552, -0.99019754, 1.78783102, -1.31610052,
1.73510175, -0.48360172, 0.62367417, -1.34180545, -0.37512931, -1.50521357,
0.08383314, 0.76165608, -0.4961646, 0.95821311, -0.68407191, 0.48299435,
-0.24323988, 0.34793412, 0.37908669, 1.19083454, 1.30218795, -0.26731035,
-0.34544132, -0.09595373, 0.50951334, 0.48896956, 0.38753818, -0.4939919,
0.02352126, 0.42013764, 0.07027765, 0.21169851, -0.24411376, -1.77793736,
-0.88370924, 0.95294025, -0.08208804, -0.95943892, 0.30280474, 1.1967013,
-1.17700948, 0.29533973};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 4, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
std::vector<float> gold = {
1.22136035, 1.3765651, 2.0611395, 1.70445494, 1.8189619, 0.2509717,
0.88815736, 1.13837946, 1.37006127, -0.53617378, 0.45759693, -0.503786,
-0.10575749, -0.81715738, 2.56316255, 0.85812927, -0.53425671, 1.38147704,
2.57874755, -1.05591061, -1.42065674, -0.25412658, -2.14494165, -2.81045272,
0.27491485, -0.04229986, 0.10181043, -0.55680682, -0.07633866, 0.313767,
-0.28202571, -1.64696179, -0.50872733, -1.08935912, 0.94291084, -0.71792156,
0.82981387, 1.14797592, 3.13989358, -0.17507726, -0.63429162, -0.72241531,
-0.61459168, -0.52561056, 0.3309648, -0.46185697, -1.60586695, -0.98590829,
0.63012062, -0.25606052, -0.69419352, -1.78299913, -0.38572706, 1.92249442,
0.3884186, -0.48153048, 0.84932351, 0.67234919, -1.07821322, -0.01208216};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-0.55248691, 0.70275958, 0.56967633, 0.88206033, -0.85088547, 0.05689149,
-0.20084703, 0.18024434, 1.0730491, 0.15913531, 0.93621628, 0.35072771,
1.28616952, 1.55384379, 0.30376261, -1.12356544, -0.64271552, -2.50703079,
-0.23994372, 0.8166084, 0.06542249, -0.17472336, -0.37665211, 0.16342699,
0.07645941, 0.65024333, -1.19883423, -0.40536776, -0.31132765, 0.78113691,
-0.16887638, 2.30797418, -0.36241233, 0.33552153, -1.05343996, -0.16909699,
-1.22608815, 1.64165613, 0.96260828, -0.16733976, 0.84211199, 1.31243813,
0.89258549, -0.48250384, -1.06005206, 1.37021342, -0.35658565, 0.26879188};
std::vector<float> b = {-0.33734601, 0.66386073, 0.41425048, 0.40190389, -0.99645073,
-0.10017067, -0.58542118, 0.48636962, 0.06301405, 1.14669128,
-0.06526677, 0.23172741, -1.49693143, -0.44464233, -0.12775566,
-1.32038007, 1.1812471, 1.22362746, -0.49013843, 0.25339836,
1.31698705, 1.54256669, 0.11211132, -0.18005487, 0.36730145,
0.97705953, -0.18909084, 0.544932, 0.32891878, 0.64250015,
-0.41381398, 0.47402562, 1.22286761, 1.07573211, -0.92988077,
-0.36340925, -1.76152377, -0.96642674, -0.79231929, 0.11517073};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 4, 5}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
-0.43180359, 0.19639674, -0.33742881, 0.98443538, -0.9021272, 1.25043704,
-0.45038184, -0.14689614, -0.91749459, 3.49467934, 3.81336312, 2.4482385,
1.49649707, 1.05889193, -3.49343731, -2.06958956, -2.52082858, -1.61401519,
-1.52966956, 0.01191848, -0.33246613, -0.70641362, -0.60391255, 0.28083355,
0.52255496, -1.08655006, 1.64648546, 0.80344255, 0.71987865, -3.00960296,
2.02318221, 3.32785057, -1.13203844, 1.81235734, 0.38067585, -0.88086897,
1.38307367, 0.42677257, 0.83759966, -0.34827442, -1.45067092, 2.09599671,
1.92882983, -0.30996324, 2.19736278, 2.32389426, 2.36741832, 1.62253915,
0.26698225, -0.00741609, -2.53680983, -0.0679954, 0.04499683, 0.85354276};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -3,10 +3,12 @@
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
float sigmoid(float x) { return 1 / (1 + expf(-x)); }
......@@ -163,6 +165,48 @@ TEST_CASE(gather_test)
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
}
TEST_CASE(squeeze_test)
......@@ -608,7 +652,7 @@ TEST_CASE(broadcast_test)
uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
auto output = result.get<int32_t>();
......@@ -628,7 +672,7 @@ TEST_CASE(add_broadcast_test)
uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -766,11 +810,11 @@ TEST_CASE(imagescaler_test)
0.35,
0.45}});
auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -833,55 +877,6 @@ TEST_CASE(reshape_test)
}
}
template <class T>
void gemm_test()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
TEST_CASE(maxpool_test)
{
migraphx::program p;
......@@ -993,6 +988,176 @@ TEST_CASE(softmax_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_0)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-2.71138556, -5.85030702, -3.74063578, -4.22915517, -6.15821977, -5.96072346, -3.57208097,
-5.78313166, -5.51435497, -3.67224195, -3.88393048, -2.57061599, -5.54431083, -6.27880025,
-5.1878749, -6.1318955, -5.29178545, -4.22537886, -3.75693516, -7.07047099, -4.45763333,
-4.66281846, -6.18290503, -4.11886536, -6.17408292, -4.18030052, -4.64570814, -4.64354473,
-3.06629525, -3.80807681, -4.69162374, -5.53605222, -3.20969275, -4.82645674, -6.63942356,
-4.73634471, -3.86003866, -5.32738981, -4.22249802, -4.51258693, -2.41455206, -3.48343199,
-5.86215889, -4.93435935, -4.83713408, -2.97471885, -2.16666459, -3.69133151, -4.71640968,
-5.64652924, -3.60709827, -5.87967748, -3.8809403, -4.33917815};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 0;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_1)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-1.77931988, -4.91824134, -2.80857010, -3.29708949, -5.22615409, -5.02865778, -2.64001529,
-4.85106598, -4.58228929, -2.74017627, -2.95186480, -1.63855031, -4.61224515, -5.34673457,
-4.25580922, -5.19982982, -4.35971977, -3.29331318, -2.82486948, -6.13840531, -3.52556765,
-3.73075278, -5.25083935, -3.18679968, -5.24201724, -3.24823484, -3.71364246, -4.14309917,
-2.56584969, -3.30763125, -4.19117818, -5.03560666, -2.70924719, -4.32601118, -6.13897800,
-4.23589915, -3.35959310, -4.82694425, -3.72205246, -4.01214137, -1.91410650, -2.98298643,
-5.36171333, -4.43391379, -4.33668852, -2.47427329, -1.66621903, -3.19088595, -4.21596412,
-5.14608368, -3.10665271, -5.37923192, -3.38049474, -3.83873259};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_2)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.79763715, -3.93655861, -1.82688737, -2.31540676, -4.24447136, -4.04697505, -1.65833256,
-3.86938325, -3.60060656, -1.81223672, -2.02392525, -0.71061076, -3.68430560, -4.41879502,
-3.32786967, -4.27189027, -3.43178022, -2.36537363, -1.35498658, -4.66852241, -2.05568475,
-2.26086988, -3.78095645, -1.71691678, -3.77213434, -1.77835194, -2.24375956, -2.74631770,
-1.16906822, -1.91084978, -2.79439671, -3.63882519, -1.31246572, -2.92922971, -4.74219653,
-2.83911768, -2.19738500, -3.66473615, -2.55984436, -2.84993327, -0.75189840, -1.82077833,
-4.19950523, -3.27170569, -3.17448042, -1.65286841, -0.84481415, -2.36948107, -3.39455924,
-4.32467880, -2.28524783, -4.55782704, -2.55908986, -3.01732771};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 2;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_3)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.33690375, -3.47582521, -1.36615397, -0.27936556, -2.20843016, -2.01093385, -0.22551114,
-2.43656183, -2.16778514, -1.57241522, -1.78410375, -0.47078926, -1.06745881, -1.80194823,
-0.71102288, -2.30719726, -1.46708721, -0.40068062, -0.42698261, -3.74051844, -1.12768078,
-1.07891856, -2.59900513, -0.53496546, -2.56139951, -0.56761711, -1.03302473, -2.09771276,
-0.52046328, -1.26224484, -1.76322959, -2.60765807, -0.28129860, -0.81424303, -2.62720985,
-0.72413100, -0.65570381, -2.12305496, -1.01816317, -2.48063402, -0.38259915, -1.45147908,
-1.84310238, -0.91530284, -0.81807757, -1.31692881, -0.50887455, -2.03354147, -1.48767160,
-2.41779116, -0.37836019, -2.56853147, -0.56979429, -1.02803214};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 3;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_4)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 4;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(conv2d_test)
{
migraphx::program p;
......@@ -1375,4 +1540,67 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fp16_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {1}};
migraphx::half a{1.5};
migraphx::half b{2.5};
migraphx::half c{4.0};
auto l0 = p.add_literal(migraphx::literal{s, {a}});
auto l1 = p.add_literal(migraphx::literal{s, {b}});
p.add_instruction(migraphx::op::add{}, l0, l1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fp32_fp16_test)
{
auto create_program = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_literal(migraphx::literal(s, data));
p.add_instruction(migraphx::op::add{}, l1, l2);
return p;
};
auto test_case = [&](std::vector<std::string>&& op_names) {
std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
auto p = create_program();
migraphx::quantize(p, op_names);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res, gold_res));
};
test_case({"all"});
test_case({"add"});
}
TEST_CASE(clip_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}});
migraphx::op::clip op;
op.max_val = 6.0;
op.min_val = 0.0;
p.add_instruction(op, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0, 0.0, 6.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
......@@ -2120,4 +2125,1242 @@ TEST_CASE(gru_bidirectional_actv_funcs)
}
}
TEST_CASE(lstm_forward)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, hidden state concatenation as output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
und);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815,
0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085,
0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755,
-0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156,
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// forward, last_output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({});
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// forward, last_cell_output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({});
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.111454,
0.247794,
0.471087,
-0.220574,
-0.048196,
0.263184,
0.283258,
-0.14882,
0.605585,
0.078598,
-0.64457,
0.119811};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_forward_more)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, 3 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({});
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0786602, -0.0613048,
0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633,
-0.0552699, 0.0252681, -0.0562072, -0.102509, -0.0372696, 0.252296, -0.144544,
0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// forward, 8 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, 0.186991, -0.0624168,
0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906,
-0.0890598, -0.135266, -0.0413375, 0.0459032, 0.0414126, 0.272303, 0.0393149,
0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// seq_len = 1
{
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.079753,
-0.289854,
0.160043,
0.115056,
0.294074,
-0.0319677,
-0.0955337,
0.104168,
0.022618,
-0.121195,
-0.4065,
-0.252054};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(lstm_reverse)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
0.6091,
1.6462,
0.8720,
0.5349,
-0.1962,
-1.7416,
-0.9912,
1.2831,
1.0896,
-0.6959};
std::vector<float> ic_data{-0.8323,
0.3998,
0.1831,
0.5938,
2.7096,
-0.1790,
0.0022,
-0.8040,
0.1578,
0.0567,
0.8069,
-0.5141};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 0},
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, 1 actv function
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs =
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123,
-0.37531,
-0.12943,
-0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
seq_len = 1;
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.104351,
-0.0471426,
-0.0905753,
0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last hidden state as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// sequence length is 1, contenation of hidden state as program output
{
migraphx::program p;
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional_actv_func)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
// 3 args, 0 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 1 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.227861, 0.328562, 0.277867, 0.272945, 0.204389, 0.296123, 0.223834, 0.311113,
0.424666, 0.173974, 0.40628, 0.286631, 0.246078, 0.199709, 0.303753, 0.301178,
0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186,
0.339438, 0.29655, 0.331832, 0.242338, 0.409384, 0.236272, 0.306045, 0.26269,
0.261246, 0.334357, 0.23622, 0.245288, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.374123, 0.283167, 0.377129, 0.245726, 0.444712, 0.203168, 0.411446, 0.269965,
0.172792, 0.296224, 0.17319, 0.352547, 0.310306, 0.262902, 0.276964, 0.295002,
0.373802, 0.366785, 0.419791, 0.393216, 0.262827, 0.371441, 0.369022, 0.298262,
0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563,
0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 2 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs =
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 4 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 5 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 6 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dead_code_elimination.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/identity.hpp>
#include <test.hpp>
struct dce_target
......@@ -129,4 +131,55 @@ TEST_CASE(undefined_test)
EXPECT(result != migraphx::literal{4});
}
TEST_CASE(duplicate_args1)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
p.add_instruction(migraphx::op::add{}, l3, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(duplicate_args2)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3);
p.add_instruction(migraphx::op::add{}, sum1, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(duplicate_args3)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3);
auto sum2 = p.add_instruction(migraphx::op::add{}, l0, sum1);
p.add_instruction(migraphx::op::add{}, sum2, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == migraphx::literal{0});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -19,6 +20,13 @@ struct eliminate_allocation_target
struct allocate
{
migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
......
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -8,6 +10,13 @@ struct concat
{
concat(std::size_t axis) { op.axis = axis; }
migraphx::op::concat op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "eliminate_concat::concat"; }
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
......@@ -49,6 +58,13 @@ struct eliminate_concat_target
struct allocate
{
migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
......
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -35,7 +40,46 @@ TEST_CASE(non_standard_op)
p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(transpose_gemm)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto ic = p.add_instruction(migraphx::op::identity{}, c);
p.add_instruction(migraphx::op::dot{}, ic, l);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
TEST_CASE(transpose_standard_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(no_packed_unary_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/op/identity.hpp>
#include <test.hpp>
struct eliminate_identity_target
{
std::string name() const { return "eliminate_identity"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::eliminate_identity{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
auto two = p.add_literal(2);
auto two_identity = p.add_instruction(migraphx::op::identity{}, two);
p.add_instruction(sum_op{}, one_identity, two_identity);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end_dependency)
{
migraphx::program p;
auto one = p.add_literal(1.0);
auto two = p.add_literal(2.0);
auto three = p.add_literal(3.0);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3.0});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct eliminate_pad_target
{
std::string name() const { return "eliminate_pad"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
migraphx::instruction_ref
create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::program& p)
{
size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]);
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
return p.add_instruction(migraphx::op::im2col{}, l_img, l_weights);
}
migraphx::instruction_ref
create_conv(migraphx::instruction_ref& l_img,
size_t channels,
migraphx::program& p,
migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_)
{
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
return p.add_instruction(op, l_img, l_weights);
}
TEST_CASE(rewrite_test)
{
migraphx::program p;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
auto l0 = create_im2col(padded_img, channels, p);
auto l1 = create_conv(padded_img, channels, p);
auto l2 = p.add_instruction(migraphx::op::pooling{}, padded_img);
p.add_instruction(migraphx::op::identity{}, l0, l1, l2);
p.compile(eliminate_pad_target{});
EXPECT(std::none_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_test_asymmetric)
{
migraphx::program p;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img);
create_im2col(padded_img, channels, p);
p.compile(eliminate_pad_target{});
EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_test_same_padding)
{
migraphx::program p;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
create_conv(padded_img, channels, p, migraphx::op::padding_mode_t::same);
p.compile(eliminate_pad_target{});
EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -128,7 +128,7 @@ TEST_CASE(print_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two);
......@@ -142,8 +142,8 @@ TEST_CASE(param_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
auto result = p.eval(
......@@ -156,8 +156,8 @@ TEST_CASE(param_error_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
......@@ -167,6 +167,22 @@ TEST_CASE(param_error_test)
"Parameter not found: y"));
}
TEST_CASE(param_shape_error_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 2}});
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 2}});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}});
},
"Incorrect shape"));
}
TEST_CASE(replace_test)
{
migraphx::program p;
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
bool is_batch_norm(migraphx::instruction& ins) { return ins.name() == "batch_norm_inference"; }
TEST_CASE(fwd_conv_batchnorm_rewrite_test)
{
std::vector<float> xdata = {
......@@ -65,4 +71,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT(migraphx::verify_range(results_vector1, results_vector2));
}
TEST_CASE(non_literal)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(any_of(p2, &is_batch_norm));
}
TEST_CASE(as_literal)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
TEST_CASE(literal_reshape)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto reshape = [&](auto ins) {
return p.add_instruction(migraphx::op::reshape{{1, 4, 1, 1}}, ins);
};
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
auto scale = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))));
auto bias = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))));
auto mean = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))));
auto variance = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/tanh.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct lowering_target
{
std::string name() const { return "gpu::lowering"; }
std::vector<migraphx::pass> get_passes(migraphx::context& gctx) const
{
auto& ctx = migraphx::any_cast<migraphx::gpu::context>(gctx);
return {migraphx::auto_contiguous{},
migraphx::gpu::lowering{ctx},
migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{},
migraphx::dead_code_elimination{}};
}
migraphx::gpu::context get_context() const { return migraphx::gpu::context{}; }
};
TEST_CASE(tanh_shape)
{
auto create_program = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = p.add_parameter("x", s);
auto tx = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto txh = p.add_instruction(migraphx::op::tanh{}, tx);
auto sum = p.add_instruction(migraphx::op::add{}, txh, txh);
p.add_instruction(migraphx::op::contiguous{}, sum);
return p;
};
auto p1 = create_program();
auto p2 = create_program();
EXPECT(p1 == p2);
p1.compile(lowering_target{});
p2.compile(lowering_target());
EXPECT(p1 == p2);
for(auto ins : iterator_for(p1))
{
if(ins->name() == "hip::allocate")
{
migraphx::shape new_s{migraphx::shape::float_type, {3, 2}, {1, 3}};
ins->replace(migraphx::gpu::hip_allocate{new_s});
}
}
EXPECT(p1 != p2);
migraphx::run_passes(p2,
{migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -10,13 +10,14 @@
#include <migraphx/type_name.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <miopen/miopen.h>
#include <future>
#include <thread>
#include "test.hpp"
#include <test.hpp>
#ifdef __clang__
#pragma clang diagnostic push
......@@ -134,7 +135,7 @@ migraphx::argument run_gpu(migraphx::program& p)
}
template <class V>
void verify_program()
void run_verify_program()
{
auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraphx::get_type_name<V>() << std::endl;
......@@ -156,7 +157,27 @@ void verify_program()
std::set_terminate(nullptr);
}
struct test_literals
template <class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template <class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
template <class T>
int verify_program<T>::static_register = auto_register_verify_program<T>(); // NOLINT
struct test_literals : verify_program<test_literals>
{
migraphx::program create_program() const
{
......@@ -171,7 +192,7 @@ struct test_literals
}
};
struct test_add
struct test_add : verify_program<test_add>
{
migraphx::program create_program() const
{
......@@ -184,7 +205,7 @@ struct test_add
}
};
struct test_add_half
struct test_add_half : verify_program<test_add_half>
{
migraphx::program create_program() const
{
......@@ -197,7 +218,7 @@ struct test_add_half
}
};
struct test_mul
struct test_mul : verify_program<test_mul>
{
migraphx::program create_program() const
{
......@@ -210,33 +231,31 @@ struct test_mul
}
};
struct test_exp
struct test_exp : verify_program<test_exp>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 10.f};
auto x = p.add_literal(s, data);
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::exp{}, x);
return p;
}
};
struct test_log
struct test_log : verify_program<test_log>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 100.f};
auto x = p.add_literal(s, data);
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::log{}, x);
return p;
}
};
struct test_sin
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
{
......@@ -248,7 +267,7 @@ struct test_sin
}
};
struct test_cos
struct test_cos : verify_program<test_cos>
{
migraphx::program create_program() const
{
......@@ -260,7 +279,7 @@ struct test_cos
}
};
struct test_tan
struct test_tan : verify_program<test_tan>
{
migraphx::program create_program() const
{
......@@ -272,7 +291,7 @@ struct test_tan
}
};
struct test_sinh
struct test_sinh : verify_program<test_sinh>
{
migraphx::program create_program() const
{
......@@ -284,7 +303,7 @@ struct test_sinh
}
};
struct test_cosh
struct test_cosh : verify_program<test_cosh>
{
migraphx::program create_program() const
{
......@@ -296,7 +315,7 @@ struct test_cosh
}
};
struct test_tanh
struct test_tanh : verify_program<test_tanh>
{
migraphx::program create_program() const
{
......@@ -307,7 +326,35 @@ struct test_tanh
}
};
struct test_asin
struct test_trans_tanh : verify_program<test_trans_tanh>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx);
p.add_instruction(migraphx::op::contiguous{}, r);
return p;
}
};
struct test_slice_sin : verify_program<test_slice_sin>
{
migraphx::program create_program() const
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l);
p.add_instruction(migraphx::op::sin{}, t);
return p;
}
};
struct test_asin : verify_program<test_asin>
{
migraphx::program create_program() const
{
......@@ -319,7 +366,7 @@ struct test_asin
}
};
struct test_acos
struct test_acos : verify_program<test_acos>
{
migraphx::program create_program() const
{
......@@ -331,7 +378,7 @@ struct test_acos
}
};
struct test_atan
struct test_atan : verify_program<test_atan>
{
migraphx::program create_program() const
{
......@@ -343,7 +390,7 @@ struct test_atan
}
};
struct test_scale
struct test_scale : verify_program<test_scale>
{
migraphx::program create_program() const
{
......@@ -351,13 +398,13 @@ struct test_scale
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y);
auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale);
return p;
}
};
struct test_slice
struct test_slice : verify_program<test_slice>
{
migraphx::program create_program() const
{
......@@ -372,7 +419,7 @@ struct test_slice
}
};
struct test_triadd
struct test_triadd : verify_program<test_triadd>
{
migraphx::program create_program() const
{
......@@ -387,7 +434,7 @@ struct test_triadd
}
};
struct test_triadd2
struct test_triadd2 : verify_program<test_triadd2>
{
migraphx::program create_program() const
{
......@@ -397,14 +444,14 @@ struct test_triadd2
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::add{}, sum, zb);
return p;
}
};
struct test_add_broadcast
struct test_add_broadcast : verify_program<test_add_broadcast>
{
migraphx::program create_program() const
{
......@@ -412,13 +459,13 @@ struct test_add_broadcast
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
};
struct test_add_broadcast2
struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{
migraphx::program create_program() const
{
......@@ -426,13 +473,13 @@ struct test_add_broadcast2
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
};
struct test_add_broadcast3
struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{
migraphx::program create_program() const
{
......@@ -440,13 +487,13 @@ struct test_add_broadcast3
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
};
struct test_add_broadcast4
struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{
migraphx::program create_program() const
{
......@@ -454,13 +501,13 @@ struct test_add_broadcast4
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
};
struct test_add_broadcast5
struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{
migraphx::program create_program() const
{
......@@ -468,13 +515,13 @@ struct test_add_broadcast5
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by);
return p;
}
};
struct test_triadd_broadcast
struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{
migraphx::program create_program() const
{
......@@ -483,14 +530,14 @@ struct test_triadd_broadcast
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y);
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
auto sum = p.add_instruction(migraphx::op::add{}, x, by);
p.add_instruction(migraphx::op::add{}, sum, z);
return p;
}
};
struct test_sub
struct test_sub : verify_program<test_sub>
{
migraphx::program create_program() const
{
......@@ -505,7 +552,7 @@ struct test_sub
}
};
struct test_sub2
struct test_sub2 : verify_program<test_sub2>
{
migraphx::program create_program() const
{
......@@ -515,14 +562,14 @@ struct test_sub2
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, zb);
return p;
}
};
struct test_softmax
struct test_softmax : verify_program<test_softmax>
{
migraphx::program create_program() const
{
......@@ -533,7 +580,7 @@ struct test_softmax
}
};
struct test_softmax2
struct test_softmax2 : verify_program<test_softmax2>
{
migraphx::program create_program() const
{
......@@ -545,7 +592,7 @@ struct test_softmax2
}
};
struct test_conv
struct test_conv : verify_program<test_conv>
{
migraphx::program create_program() const
{
......@@ -559,7 +606,7 @@ struct test_conv
}
};
struct test_conv2
struct test_conv2 : verify_program<test_conv2>
{
migraphx::program create_program() const
{
......@@ -573,7 +620,7 @@ struct test_conv2
}
};
struct test_group_conv
struct test_group_conv : verify_program<test_group_conv>
{
migraphx::program create_program() const
{
......@@ -589,7 +636,7 @@ struct test_group_conv
}
};
struct test_conv_relu
struct test_conv_relu : verify_program<test_conv_relu>
{
migraphx::program create_program() const
{
......@@ -604,7 +651,7 @@ struct test_conv_relu
}
};
struct test_conv_relu_half
struct test_conv_relu_half : verify_program<test_conv_relu_half>
{
migraphx::program create_program() const
{
......@@ -619,7 +666,7 @@ struct test_conv_relu_half
}
};
struct test_add_relu
struct test_add_relu : verify_program<test_add_relu>
{
migraphx::program create_program() const
{
......@@ -632,7 +679,7 @@ struct test_add_relu
}
};
struct test_sigmoid
struct test_sigmoid : verify_program<test_sigmoid>
{
migraphx::program create_program() const
{
......@@ -643,7 +690,7 @@ struct test_sigmoid
}
};
struct test_abs
struct test_abs : verify_program<test_abs>
{
migraphx::program create_program() const
{
......@@ -654,7 +701,22 @@ struct test_abs
}
};
struct test_leaky_relu
struct test_trans_abs : verify_program<test_trans_abs>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto absx = p.add_instruction(migraphx::op::abs{}, tx);
auto r = p.add_instruction(migraphx::op::add{}, absx, absx);
p.add_instruction(migraphx::op::contiguous{}, r);
return p;
}
};
struct test_leaky_relu : verify_program<test_leaky_relu>
{
migraphx::program create_program() const
{
......@@ -665,7 +727,7 @@ struct test_leaky_relu
}
};
struct test_elu
struct test_elu : verify_program<test_elu>
{
migraphx::program create_program() const
{
......@@ -676,7 +738,7 @@ struct test_elu
}
};
struct test_relu_lrn
struct test_relu_lrn : verify_program<test_relu_lrn>
{
migraphx::program create_program() const
{
......@@ -688,7 +750,7 @@ struct test_relu_lrn
}
};
struct test_conv_pooling
struct test_conv_pooling : verify_program<test_conv_pooling>
{
migraphx::program create_program() const
{
......@@ -704,7 +766,7 @@ struct test_conv_pooling
}
};
struct test_global_avg_pooling
struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{
migraphx::program create_program() const
{
......@@ -719,7 +781,7 @@ struct test_global_avg_pooling
}
};
struct test_global_max_pooling
struct test_global_max_pooling : verify_program<test_global_max_pooling>
{
migraphx::program create_program() const
{
......@@ -734,7 +796,7 @@ struct test_global_max_pooling
}
};
struct test_gemm
struct test_gemm : verify_program<test_gemm>
{
migraphx::program create_program() const
{
......@@ -746,7 +808,19 @@ struct test_gemm
}
};
struct test_gemm_half
struct test_gemm_ex : verify_program<test_gemm_ex>
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
p.add_instruction(migraphx::op::dot{}, a, b);
return p;
}
};
struct test_gemm_half : verify_program<test_gemm_half>
{
migraphx::program create_program() const
{
......@@ -758,7 +832,7 @@ struct test_gemm_half
}
};
struct test_gemm_ld
struct test_gemm_ld //: verify_program<test_gemm_ld>
{
migraphx::program create_program() const
{
......@@ -772,7 +846,7 @@ struct test_gemm_ld
}
};
struct test_gemm_transposeb
struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{
migraphx::program create_program() const
{
......@@ -785,7 +859,20 @@ struct test_gemm_transposeb
}
};
struct test_gemm_transposea
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto bt = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, b);
p.add_instruction(migraphx::op::dot{}, a, bt);
return p;
}
};
struct test_gemm_transposea : verify_program<test_gemm_transposea>
{
migraphx::program create_program() const
{
......@@ -798,7 +885,20 @@ struct test_gemm_transposea
}
};
struct test_gemm_transposeab
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
auto at = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a);
p.add_instruction(migraphx::op::dot{}, at, b);
return p;
}
};
struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{
migraphx::program create_program() const
{
......@@ -812,7 +912,333 @@ struct test_gemm_transposeab
}
};
struct test_contiguous
struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
};
struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
};
struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, bl1, bl2);
return p;
}
};
struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2);
return p;
}
};
struct gemm_2args_vv : verify_program<gemm_2args_vv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {8}};
migraphx::shape m2_shape{migraphx::shape::float_type, {8}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
float alpha = 0.23f;
auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sres);
return p;
}
};
struct gemm_2args_mv : verify_program<gemm_2args_mv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, ul2);
return p;
}
};
struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
p.add_instruction(migraphx::op::dot{}, l1, bul2);
return p;
}
};
struct gemm_2args_vm : verify_program<gemm_2args_vm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
return p;
}
};
struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto l2 = p.add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
return p;
}
};
struct gemm_multi_3args : verify_program<gemm_multi_3args>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.0f;
float beta = 1.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
{
......@@ -825,7 +1251,7 @@ struct test_contiguous
}
};
struct test_transpose
struct test_transpose : verify_program<test_transpose>
{
migraphx::program create_program() const
{
......@@ -839,7 +1265,7 @@ struct test_transpose
}
};
struct test_batchnorm_inference_2
struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{
const size_t width = 14;
const size_t height = 14;
......@@ -862,7 +1288,7 @@ struct test_batchnorm_inference_2
}
};
struct test_batchnorm_inference
struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{
const size_t width = 3;
const size_t height = 3;
......@@ -885,7 +1311,18 @@ struct test_batchnorm_inference
}
};
struct test_conv_bn
struct test_clip : verify_program<test_clip>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, x);
return p;
}
};
struct test_conv_bn : verify_program<test_conv_bn>
{
migraphx::program create_program() const
{
......@@ -906,7 +1343,7 @@ struct test_conv_bn
}
};
struct test_conv_bn_relu_pooling
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{
migraphx::program create_program() const
{
......@@ -930,7 +1367,7 @@ struct test_conv_bn_relu_pooling
}
};
struct test_concat
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
{
......@@ -947,7 +1384,7 @@ struct test_concat
}
};
struct test_concat2
struct test_concat2 : verify_program<test_concat2>
{
migraphx::program create_program() const
{
......@@ -964,7 +1401,7 @@ struct test_concat2
}
};
struct test_concat_relu
struct test_concat_relu : verify_program<test_concat_relu>
{
migraphx::program create_program() const
{
......@@ -985,7 +1422,7 @@ struct test_concat_relu
}
};
struct test_pad
struct test_pad : verify_program<test_pad>
{
migraphx::program create_program() const
{
......@@ -1004,7 +1441,7 @@ struct test_pad
}
};
struct test_pooling_autopad
struct test_pooling_autopad : verify_program<test_pooling_autopad>
{
migraphx::program create_program() const
{
......@@ -1020,7 +1457,7 @@ struct test_pooling_autopad
}
};
struct test_gather
struct test_gather : verify_program<test_gather>
{
migraphx::program create_program() const
{
......@@ -1036,7 +1473,7 @@ struct test_gather
}
};
struct test_gather_neg_axis
struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
{
migraphx::program create_program() const
{
......@@ -1052,10 +1489,58 @@ struct test_gather_neg_axis
}
};
void manual_identity()
struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{
migraphx::program p;
std::vector<float> data0 = {0, 1, 2, 3};
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_scalar_index : verify_program<test_gather_scalar_index>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_1d_index : verify_program<test_gather_1d_index>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
void manual_identity()
{
migraphx::program p;
std::vector<float> data0 = {0, 1, 2, 3};
migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0});
p.add_instruction(migraphx::op::identity{}, l0);
......@@ -1098,7 +1583,7 @@ void manual_test_concat_relu()
std::cout << result << std::endl;
}
struct test_conv_bn_relu_pooling2
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{
static migraphx::instruction_ref
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
......@@ -1135,7 +1620,7 @@ struct test_conv_bn_relu_pooling2
}
};
struct test_rnn_forward
struct test_rnn_forward : verify_program<test_rnn_forward>
{
migraphx::program create_program() const
{
......@@ -1177,7 +1662,7 @@ struct test_rnn_forward
}
};
struct test_rnn_forward10
struct test_rnn_forward10 : verify_program<test_rnn_forward10>
{
migraphx::program create_program() const
{
......@@ -1219,7 +1704,7 @@ struct test_rnn_forward10
}
};
struct test_rnn_reverse
struct test_rnn_reverse : verify_program<test_rnn_reverse>
{
migraphx::program create_program() const
{
......@@ -1259,7 +1744,7 @@ struct test_rnn_reverse
}
};
struct test_rnn_reverse2
struct test_rnn_reverse2 : verify_program<test_rnn_reverse2>
{
migraphx::program create_program() const
{
......@@ -1299,7 +1784,7 @@ struct test_rnn_reverse2
}
};
struct test_rnn_3args
struct test_rnn_3args : verify_program<test_rnn_3args>
{
migraphx::program create_program() const
{
......@@ -1331,7 +1816,7 @@ struct test_rnn_3args
}
};
struct test_rnn_4args
struct test_rnn_4args : verify_program<test_rnn_4args>
{
migraphx::program create_program() const
{
......@@ -1366,7 +1851,7 @@ struct test_rnn_4args
}
};
struct test_rnn_5args
struct test_rnn_5args : verify_program<test_rnn_5args>
{
migraphx::program create_program() const
{
......@@ -1405,7 +1890,7 @@ struct test_rnn_5args
}
};
struct test_rnn_bidirectional
struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
{
migraphx::program create_program() const
{
......@@ -1447,7 +1932,7 @@ struct test_rnn_bidirectional
}
};
struct test_rnn_bidirectional10
struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
{
migraphx::program create_program() const
{
......@@ -1488,7 +1973,7 @@ struct test_rnn_bidirectional10
}
};
struct test_rnn_bi_3args
struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
{
migraphx::program create_program() const
{
......@@ -1523,7 +2008,7 @@ struct test_rnn_bi_3args
}
};
struct test_gru_forward_last
struct test_gru_forward_last : verify_program<test_gru_forward_last>
{
migraphx::program create_program() const
{
......@@ -1567,7 +2052,7 @@ struct test_gru_forward_last
}
};
struct test_gru_forward_hs
struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
{
migraphx::program create_program() const
{
......@@ -1609,7 +2094,7 @@ struct test_gru_forward_hs
}
};
struct test_gru_forward_3args_und
struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
{
migraphx::program create_program() const
{
......@@ -1645,7 +2130,7 @@ struct test_gru_forward_3args_und
}
};
struct test_gru_forward_3args
struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
{
migraphx::program create_program() const
{
......@@ -1677,7 +2162,7 @@ struct test_gru_forward_3args
}
};
struct test_gru_forward_seq1
struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
{
migraphx::program create_program() const
{
......@@ -1709,7 +2194,7 @@ struct test_gru_forward_seq1
}
};
struct test_gru_forward_default_actv
struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_actv>
{
migraphx::program create_program() const
{
......@@ -1739,7 +2224,7 @@ struct test_gru_forward_default_actv
}
};
struct test_gru_forward_default_actv1
struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{
migraphx::program create_program() const
{
......@@ -1780,7 +2265,7 @@ struct test_gru_forward_default_actv1
}
};
struct test_gru_reverse_last
struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{
migraphx::program create_program() const
{
......@@ -1824,7 +2309,7 @@ struct test_gru_reverse_last
}
};
struct test_gru_reverse_3args
struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
{
migraphx::program create_program() const
{
......@@ -1856,7 +2341,7 @@ struct test_gru_reverse_3args
}
};
struct test_gru_bidirct_last
struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
{
migraphx::program create_program() const
{
......@@ -1900,7 +2385,7 @@ struct test_gru_bidirct_last
}
};
struct test_gru_bidirct_hs
struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
{
migraphx::program create_program() const
{
......@@ -1942,7 +2427,7 @@ struct test_gru_bidirct_hs
}
};
struct test_gru_bidirct_3args_und
struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und>
{
migraphx::program create_program() const
{
......@@ -1978,7 +2463,7 @@ struct test_gru_bidirct_3args_und
}
};
struct test_gru_bidirct_3args
struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
{
migraphx::program create_program() const
{
......@@ -2010,7 +2495,7 @@ struct test_gru_bidirct_3args
}
};
struct test_gru_bidirct_seq1
struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1>
{
migraphx::program create_program() const
{
......@@ -2042,7 +2527,7 @@ struct test_gru_bidirct_seq1
}
};
struct test_gru_bidirct_default_actv
struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_actv>
{
migraphx::program create_program() const
{
......@@ -2072,7 +2557,7 @@ struct test_gru_bidirct_default_actv
}
};
struct test_gru_bidirct_default_actv1
struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_actv1>
{
migraphx::program create_program() const
{
......@@ -2114,94 +2599,816 @@ struct test_gru_bidirct_default_actv1
}
};
int main()
{
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>();
verify_program<test_abs>();
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_pad>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_exp>();
verify_program<test_log>();
verify_program<test_sin>();
verify_program<test_cos>();
verify_program<test_tan>();
verify_program<test_sinh>();
verify_program<test_cosh>();
verify_program<test_tanh>();
verify_program<test_asin>();
verify_program<test_acos>();
verify_program<test_atan>();
verify_program<test_scale>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_sub>();
verify_program<test_sub2>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_group_conv>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_sigmoid>();
verify_program<test_elu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>();
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>();
verify_program<test_rnn_reverse2>();
verify_program<test_rnn_3args>();
verify_program<test_rnn_4args>();
verify_program<test_rnn_5args>();
verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>();
verify_program<test_rnn_bi_3args>();
verify_program<test_gru_forward_last>();
verify_program<test_gru_forward_hs>();
verify_program<test_gru_forward_3args_und>();
verify_program<test_gru_forward_3args>();
verify_program<test_gru_forward_seq1>();
verify_program<test_gru_forward_default_actv>();
verify_program<test_gru_forward_default_actv1>();
verify_program<test_gru_reverse_last>();
verify_program<test_gru_reverse_3args>();
verify_program<test_gru_bidirct_last>();
verify_program<test_gru_bidirct_hs>();
verify_program<test_gru_bidirct_3args_und>();
verify_program<test_gru_bidirct_3args>();
verify_program<test_gru_bidirct_seq1>();
verify_program<test_gru_bidirct_default_actv>();
verify_program<test_gru_bidirct_default_actv1>();
}
struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
return p;
}
};
struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und,
und,
und);
return p;
}
};
struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
return p;
}
};
struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
und,
und,
und,
und,
und);
return p;
}
};
struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
template <int Axis>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template struct test_logsoftmax<4>;
template <int Axis>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template struct test_logsoftmax_1<0>;
template struct test_logsoftmax_1<1>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"all"});
return p;
};
};
struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"add"});
return p;
};
};
struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"add"});
return p;
};
};
struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
{
migraphx::program create_program()
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"sub"});
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -36,6 +36,18 @@ inline std::ostream& operator<<(std::ostream& s, std::nullptr_t)
return s;
}
template <class T>
inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
{
s << "{ ";
for(auto&& x : v)
{
s << x << ", ";
}
s << "}";
return s;
}
template <class T, class U, class Operator>
struct expression
{
......@@ -192,10 +204,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases().emplace_back(std::move(name), std::move(f));
}
struct auto_register
struct auto_register_test_case
{
template <class F>
auto_register(const char* name, F f) noexcept
auto_register_test_case(const char* name, F f) noexcept
{
add_test_case(name, f);
}
......@@ -258,9 +270,9 @@ inline void run(int argc, const char* argv[])
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
#define TEST_CASE_REGISTER(...) \
static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
......
#include <migraphx/memory_coloring.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
......@@ -18,11 +18,18 @@ struct memory_coloring_target
struct allocate
{
migraphx::shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
migraphx::check_shapes{inputs, *this}.has(1);
return inputs.front();
migraphx::check_shapes{inputs, *this}.has(0);
return s;
}
migraphx::argument compute(migraphx::context&,
const migraphx::shape& output_shape,
......@@ -34,8 +41,7 @@ struct allocate
migraphx::instruction_ref add_alloc(migraphx::program& p, const migraphx::shape& s)
{
auto a0 = p.add_outline(s);
return p.add_instruction(allocate{}, a0);
return p.add_instruction(allocate{s});
}
bool no_allocate(const migraphx::program& p)
......
add-fp16-example:m

0
12"Add test-add-fp16* 
*|B0* 
*B1Z
0


Z
1


b
2


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment