Commit 9b53cf55 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'extend_gemm_op' into seq2seq_example

parents 0db15370 ad8f88f5
#ifndef MIGRAPHX_GUARD_RTGLIB_RELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_RELU_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_relu
{
shared<activation_descriptor> ad;
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraphx/manage_ptr.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
#include <rocblas.h>
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_sigmoid
{
shared<activation_descriptor> ad;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIN_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/sin.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SINH_HPP
#define MIGRAPHX_GUARD_RTGLIB_SINH_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/sinh.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_softmax
{
op::softmax op;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_SUB_HPP
#define MIGRAPHX_GUARD_RTGLIB_SUB_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/sub.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_TAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_TAN_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/tan.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_TANH_HPP
#define MIGRAPHX_GUARD_RTGLIB_TANH_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_tanh
{
shared<activation_descriptor> ad;
......
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/pad.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/relu.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/tanh.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -173,7 +173,7 @@ TEST_CASE(gather_test)
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
......@@ -194,7 +194,7 @@ TEST_CASE(gather_test)
migraphx::shape s{migraphx::shape::float_type, {3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
......@@ -925,6 +925,55 @@ void gemm_test()
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
template <class T>
void gemm_test_ex()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test_ex<float>)
TEST_CASE_REGISTER(gemm_test_ex<double>)
TEST_CASE(maxpool_test)
{
migraphx::program p;
......
......@@ -746,6 +746,18 @@ struct test_gemm
}
};
struct test_gemm_ex
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
p.add_instruction(migraphx::op::dot{}, a, b);
return p;
}
};
struct test_gemm_half
{
migraphx::program create_program() const
......@@ -785,6 +797,19 @@ struct test_gemm_transposeb
}
};
struct test_gemm_transposeb_ex
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto bt = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, b);
p.add_instruction(migraphx::op::dot{}, a, bt);
return p;
}
};
struct test_gemm_transposea
{
migraphx::program create_program() const
......@@ -798,6 +823,19 @@ struct test_gemm_transposea
}
};
struct test_gemm_transposea_ex
{
migraphx::program create_program() const
{
migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}});
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}});
auto at = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a);
p.add_instruction(migraphx::op::dot{}, at, b);
return p;
}
};
struct test_gemm_transposeab
{
migraphx::program create_program() const
......@@ -1074,7 +1112,7 @@ struct test_gather_scalar_output
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
......@@ -1090,7 +1128,7 @@ struct test_gather_scalar_index
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
......@@ -2936,10 +2974,13 @@ int main()
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_ex>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposeb_ex>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposea_ex>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment