Commit 6ea1e1be authored by wsttiger

Merge branch 'master' into bn-miopen-inference

parents bf336b27 39151d27
...
@@ -40,7 +40,7 @@ Of course, this program will always produce the same value which is quite uninteresting
     program p;
     instruction_ref x = p.add_parameter("x", {shape::int64_type});
     instruction_ref two = p.add_literal(2);
-    p.add_instruction(add{}, one, two);
+    p.add_instruction(add{}, x, two);
     p.compile(cpu::target{});

 This adds a parameter of type ``int64``, and compiles it for the ``cpu``. To run the program, we need to pass the parameter to it when we call `eval <migraph::program::eval>`::
...
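The literal block that follows the ``::`` above is collapsed in this diff. A minimal sketch of the call it describes, assuming a ``literal::get_argument`` helper that converts a literal into an ``argument`` (that helper is an assumption here, not text from the diff):

    // Hypothetical sketch: supply a value for the "x" parameter at eval time.
    program::parameter_map params;
    params["x"] = literal{8}.get_argument();
    argument result = p.eval(params);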
...
@@ -11,6 +11,11 @@ dead_code_elimination
 .. doxygenstruct:: migraph::dead_code_elimination

+auto_contiguous
+---------------
+
+.. doxygenstruct:: migraph::gpu::auto_contiguous
+
 write_literals
 --------------
...
...
@@ -313,6 +313,8 @@ struct reshape
 struct gemm
 {
+    float alpha = 1.0;
+    float beta  = 0.0;
     std::string name() const { return "gemm"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
@@ -322,7 +324,8 @@ struct gemm
         auto t = a.type();
         if(a.lens()[1] != b.lens()[0])
-            MIGRAPH_THROW("Inner dimensions do not match");
+            MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
+                          to_string_range(b.lens()) + "}");
         return {t, {a.lens()[0], b.lens()[1]}};
     }
...
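The rule compute_shape enforces is the standard matrix-product one: a {m, k} input times a {k, n} input yields {m, n}, and a mismatch now reports both operand shapes. A small sketch of the behavior, assuming the operator is invoked directly:

    migraph::gemm op{};
    migraph::shape a{migraph::shape::float_type, {4, 5}};
    migraph::shape b{migraph::shape::float_type, {5, 3}};
    auto out = op.compute_shape({a, b}); // out.lens() == {4, 3}
    // compute_shape({a, a}) would throw:
    //   "Inner dimensions do not match: {4, 5} x {4, 5}"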
...
@@ -71,6 +71,8 @@ struct program
     shape get_parameter_shape(std::string name);
+    std::unordered_map<std::string, shape> get_parameter_shapes() const;
+
     argument eval(parameter_map params) const;
     bool has_instruction(instruction_ref ins) const;
...
...
@@ -61,9 +61,12 @@ struct shape
     std::size_t elements() const;
     std::size_t bytes() const;
+    /// Map multiple indices to a space index
     std::size_t index(std::initializer_list<std::size_t> l) const;
+    /// Map multiple indices to a space index
     std::size_t index(const std::vector<std::size_t>& l) const;
+    /// Map multiple indices from a range of iterators to a space index
     template <class Iterator>
     std::size_t index(Iterator start, Iterator last) const
     {
@@ -72,12 +75,16 @@ struct shape
         return std::inner_product(start, last, this->strides().begin(), std::size_t{0});
     }
-    // Map element index to space index
+    /// Map an element index to a space index
     std::size_t index(std::size_t i) const;
+    /// Returns true if the shape is packed with no padding
    bool packed() const;
+    /// Returns true if the shape has been transposed, that is, the strides are not in descending order
    bool transposed() const;
+    /// Returns true if the shape is broadcasting a dimension, that is, one of the strides is zero
    bool broadcasted() const;
+    /// Returns true if the shape is in its standard format, that is, both packed and not transposed
    bool standard() const;
     friend bool operator==(const shape& x, const shape& y);
...
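Since index is an inner product of indices with strides, a packed row-major shape behaves as expected. A small sketch, assuming a shape constructed without explicit strides gets the default packed row-major layout:

    // A {4, 5} shape with default strides {5, 1}: index({i, j}) == i * 5 + j.
    migraph::shape s{migraph::shape::float_type, {4, 5}};
    assert(s.index({2, 3}) == 13);
    assert(s.packed() and s.standard() and not s.transposed());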
...
@@ -29,6 +29,7 @@ struct tensor_view
     template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
     const T& operator()(Ts... xs) const
     {
+        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
         assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
         return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
     }
@@ -36,6 +37,7 @@ struct tensor_view
     template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
     T& operator()(Ts... xs)
     {
+        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
         assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
         return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
     }
...
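One caveat about the new first assert: operator< on std::vector compares lexicographically, not element-wise, so it is a coarse sanity check rather than a strict per-dimension bounds check; the flat-index assert on the following line remains the stronger guard. For example:

    std::vector<std::size_t> idx{0, 9};
    std::vector<std::size_t> lens{2, 3};
    assert(idx < lens); // holds: 0 < 2 decides the comparison, 9 is never checked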
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP

#include <string>
#include <typeinfo>

namespace migraph {

// Produce a human-readable name for a type by parsing the compiler's
// pretty-printed signature of this function (or typeid on MSVC).
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
{
    static std::string name;

    if(name.empty())
    {
#ifdef _MSC_VER
        // MSVC: typeid().name() yields e.g. "struct global_class";
        // substr(7) strips the "struct " prefix.
        name = typeid(PrivateMigraphTypeNameProbe).name();
        name = name.substr(7);
#else
        // GCC/Clang: extract the template argument from __PRETTY_FUNCTION__.
        // sizeof includes the null terminator, which accounts for the space
        // after the '='.
        const char parameter_name[] = "PrivateMigraphTypeNameProbe =";

        name = __PRETTY_FUNCTION__;

        auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
        auto length = name.find_last_of(",") - begin;
#else
        auto length = name.find_first_of("];", begin) - begin;
#endif
        name = name.substr(begin, length);
#endif
    }

    return name;
}

template <class T>
const std::string& get_type_name(const T&)
{
    return migraph::get_type_name<T>();
}

} // namespace migraph

#endif
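To see why the string surgery works, consider what GCC pretty-prints for the probe function (illustrative output only; the exact text varies by compiler and version):

    // __PRETTY_FUNCTION__ inside get_type_name<foo::ns_class>() reads roughly:
    //   const string& migraph::get_type_name() [with PrivateMigraphTypeNameProbe = foo::ns_class]
    // begin points just past "PrivateMigraphTypeNameProbe = ", and
    // find_first_of("];", begin) marks the end of the type, leaving
    //   name == "foo::ns_class"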
...
@@ -115,6 +115,20 @@ shape program::get_parameter_shape(std::string name)
     return {};
 }

+std::unordered_map<std::string, shape> program::get_parameter_shapes() const
+{
+    std::unordered_map<std::string, shape> result;
+    for(auto&& ins : impl->instructions)
+    {
+        if(ins.op.name() == "@param")
+        {
+            auto&& name  = any_cast<builtin::param>(ins.op).parameter;
+            result[name] = ins.result;
+        }
+    }
+    return result;
+}
+
 bool program::has_instruction(instruction_ref ins) const
 {
     return std::find_if(
...
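This accessor is what lets the GPU verification test later in this diff drop its per-test create_params boilerplate. A typical use, with prog standing in for any compiled program:

    // Generate an input argument for every parameter the program declares.
    migraph::program::parameter_map m;
    for(auto&& x : prog.get_parameter_shapes())
        m[x.first] = migraph::generate_argument(x.second);
    auto result = prog.eval(m);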
...
@@ -2,7 +2,14 @@
 add_library(migraph_cpu
     cpu_target.cpp
     cpu_lowering.cpp
+    gemm.cpp
 )

+find_path(BLAZE_INCLUDE blaze/Blaze.h)
+find_package(Threads)
+
 rocm_clang_tidy_check(migraph_cpu)
-target_link_libraries(migraph_cpu migraph)
+target_link_libraries(migraph_cpu migraph Threads::Threads)
 target_include_directories(migraph_cpu PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
+target_include_directories(migraph_cpu PRIVATE ${BLAZE_INCLUDE})
+target_compile_definitions(migraph_cpu PRIVATE -DBLAZE_USE_CPP_THREADS)
...
@@ -5,6 +5,7 @@
 #include <migraph/operators.hpp>
 #include <migraph/shape_for_each.hpp>
 #include <migraph/iterator_for.hpp>
+#include <migraph/cpu/gemm.hpp>
 #include <unordered_map>

 namespace migraph {
@@ -229,35 +230,7 @@ struct cpu_gemm
     argument compute(context&, shape output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        visit_all(result, args[0], args[1])([&](auto cmat, auto amat, auto bmat) {
-            auto m = amat.get_shape().lens()[0];
-            auto n = bmat.get_shape().lens()[1];
-            auto k = bmat.get_shape().lens()[0];
-
-            auto a = amat.data();
-            auto b = bmat.data();
-            auto c = cmat.data();
-            for(int ii = 0; ii < m; ii++)
-            {
-                for(int jj = 0; jj < n; jj++)
-                {
-                    c[ii * n + jj] = 0;
-                }
-            }
-            for(int ii = 0; ii < m; ii++)
-            {
-                for(int kk = 0; kk < k; kk++)
-                {
-                    auto aik  = a[ii * k + kk];
-                    auto* bkj = &b[kk * n];
-                    auto* cij = &c[ii * n];
-                    for(int jj = 0; jj < n; jj++, cij++, bkj++)
-                    {
-                        *cij += aik * (*bkj);
-                    }
-                }
-            }
-        });
+        migemm(result, args[0], args[1], op.alpha, op.beta);
         return result;
     }
 };
...
#include <migraph/cpu/gemm.hpp>
#include <migraph/dfor.hpp>
#include <migraph/requires.hpp>

#include <blaze/math/CustomMatrix.h>

#include <cassert>

namespace migraph {
namespace cpu {

template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT

// Wrap a 2D tensor_view in a blaze matrix without copying. For a transposed
// view, swap the dimensions and use the stride of the underlying layout.
template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& s = x.get_shape();
    assert(s.lens().size() == 2);
    if(s.transposed())
        return matrix<T>{x.data(), s.lens()[1], s.lens()[0], s.strides()[1]};
    return matrix<T>{x.data(), s.lens()[0], s.lens()[1], s.strides()[0]};
}

template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto mat = make_mat(x);
    if(x.get_shape().transposed())
        f(blaze::trans(mat));
    else
        f(mat);
}

// Only float currently takes the blaze fast path; every other type falls
// back to the naive loop below.
template <class T>
struct is_fast_gemm_type : std::false_type
{
};

template <>
struct is_fast_gemm_type<float> : std::true_type
{
};

// Fast path: C = alpha * A * B + beta * C via blaze.
template <class T>
void migemm_impl(tensor_view<T> cmat,
                 tensor_view<T> amat,
                 tensor_view<T> bmat,
                 float alpha,
                 float beta,
                 std::true_type)
{
    visit_mat(amat, [&](const auto& a) {
        visit_mat(bmat, [&](const auto& b) {
            auto c = make_mat(cmat);
            c      = (a * b) * alpha + beta * c;
        });
    });
}

// Fallback path: naive triple loop computing C = alpha * A * B + beta * C.
// Note the beta term is kept outside the alpha scaling so that both paths
// compute the same result.
template <class T>
void migemm_impl(tensor_view<T> cmat,
                 tensor_view<T> amat,
                 tensor_view<T> bmat,
                 float alpha,
                 float beta,
                 std::false_type)
{
    auto m = cmat.get_shape().lens()[0];
    auto n = cmat.get_shape().lens()[1];
    auto k = amat.get_shape().lens()[1];

    assert(amat.get_shape().lens()[1] == bmat.get_shape().lens()[0]);
    assert(m == amat.get_shape().lens()[0]);
    assert(n == bmat.get_shape().lens()[1]);

    dfor(m, n)([&](auto ii, auto jj) {
        double s = 0.0;
        dfor(k)([&](auto kk) { s += amat(ii, kk) * bmat(kk, jj); });
        cmat(ii, jj) = alpha * s + beta * cmat(ii, jj);
    });
}

template <class T>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
    migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}

void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
    visit_all(c_arg, a_arg, b_arg)(
        [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
}

} // namespace cpu
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP

#include <migraph/argument.hpp>

namespace migraph {
namespace cpu {

// Computes c_arg = alpha * a_arg * b_arg + beta * c_arg on 2D arguments.
void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);

} // namespace cpu
} // namespace migraph

#endif
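A minimal sketch of calling migemm directly, assuming a and b are existing 2D float arguments with shapes {4, 5} and {5, 3}:

    // c = 1.0 * a * b + 0.0 * c, written into a freshly allocated {4, 3} result.
    migraph::argument c{migraph::shape{migraph::shape::float_type, {4, 3}}};
    migraph::cpu::migemm(c, a, b, 1.0f, 0.0f);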
...
@@ -202,15 +202,17 @@ struct miopen_gemm
     {
         float alpha = 1.0f;
         float beta  = 0.0f;
-        rocblas_int lda = args[0].get_shape().lens()[1];
-        rocblas_int ldb = args[1].get_shape().lens()[1];
-        rocblas_int ldc = args[2].get_shape().lens()[1];
+        bool transa = args[0].get_shape().transposed();
+        bool transb = args[1].get_shape().transposed();
+        rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
+        rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
+        rocblas_int ldc = args[2].get_shape().strides()[0];
         rocblas_int m = output_shape.lens()[0];
         rocblas_int n = output_shape.lens()[1];
         rocblas_int k = args[0].get_shape().lens()[1];
         rocblas_sgemm(ctx.rbhandle.get(),
-                      rocblas_operation_none,
-                      rocblas_operation_none,
+                      transb ? rocblas_operation_transpose : rocblas_operation_none,
+                      transa ? rocblas_operation_transpose : rocblas_operation_none,
                       n,
                       m,
                       k,
...
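The seemingly reversed argument order is the usual row-major trick for a column-major BLAS: a row-major matrix reinterpreted as column-major is its transpose, and since (A * B)^T = B^T * A^T, passing B before A with m and n swapped makes rocBLAS compute the row-major product in place. The new transa/transb flags extend this so that an input whose strides mark it as transposed maps onto rocblas_operation_transpose instead of requiring a contiguous copy first. A sketch of the reasoning (not part of the diff):

    // Row-major C (m x n) viewed column-major is C^T (n x m), and C^T = B^T * A^T,
    // so sgemm(opB, opA, n, m, k, ..., B, ldb, A, lda, ..., C, ldc) over the
    // row-major buffers yields C = A * B without any data movement.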
...
@@ -242,48 +242,49 @@ void reshape_test()
     }
 }

+template <class T>
 void gemm_test()
 {
     migraph::program p;
-    std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
+    std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
                         1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
                         -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
                         -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
-    std::vector<float> b = {6.09568541e-01,
+    std::vector<T> b = {6.09568541e-01,
                         -6.10527007e-01,
                         3.66646462e-01,
                         1.18951101e-01,
                         5.58777432e-01,
                         -3.21296298e-01,
                         -5.95997198e-01,
                         -5.01425721e-01,
                         -2.84606807e-01,
                         -5.73673557e-01,
                         -8.99430260e-01,
                         -4.25103093e-01,
                         1.53027987e+00,
                         -3.81407415e-04,
                         -3.29650255e-01};
-    std::vector<float> c = {-1.56327541e+00,
+    std::vector<T> c = {-1.56327541e+00,
                         -7.09570140e-01,
                         -5.37424982e-01,
                         -2.22994831e-01,
                         -2.15586437e+00,
                         2.09177941e-03,
                         -1.47279677e+00,
                         2.02627040e-01,
                         -6.04527691e-01,
                         -1.29885596e+00,
                         2.16294914e+00,
                         -1.48101497e-01};
-    migraph::shape a_shape{migraph::shape::float_type, {4, 5}};
+    migraph::shape a_shape{migraph::shape::get_type<T>{}, {4, 5}};
     auto al = p.add_literal(migraph::literal{a_shape, a});
-    migraph::shape b_shape{migraph::shape::float_type, {5, 3}};
+    migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}};
     auto bl = p.add_literal(migraph::literal{b_shape, b});
     p.add_instruction(migraph::gemm{}, al, bl);
     p.compile(migraph::cpu::cpu_target{});
     auto result = p.eval({});
-    std::vector<float> results_vector(12);
+    std::vector<T> results_vector(12);
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     float tol = 1e-6;
     for(int i = 0; i < results_vector.size(); i++)
@@ -656,7 +657,8 @@ int main()
     add_broadcast_test();
     sub_test();
     mul_test();
-    gemm_test();
+    gemm_test<float>();
+    gemm_test<double>();
     reshape_test();
     transpose_test();
     contiguous_test();
...
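Instantiating the test for double is more than type coverage: is_fast_gemm_type in the new gemm.cpp is specialized only for float, so gemm_test&lt;float&gt;() exercises the blaze fast path while gemm_test&lt;double&gt;() exercises the naive dfor fallback, checking both implementations against the same data.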
...
@@ -7,6 +7,7 @@
 #include <migraph/gpu/miopen.hpp>
 #include <migraph/gpu/hip.hpp>
 #include <migraph/manage_ptr.hpp>
+#include <migraph/type_name.hpp>
 #include <miopen/miopen.h>
@@ -19,7 +20,12 @@ migraph::argument run_cpu()
     V v;
     auto p = v.create_program();
     p.compile(migraph::cpu::cpu_target{});
-    return p.eval(v.create_params());
+    migraph::program::parameter_map m;
+    for(auto&& x : p.get_parameter_shapes())
+    {
+        m[x.first] = migraph::generate_argument(x.second);
+    }
+    return p.eval(m);
 }

 template <class V>
@@ -29,14 +35,12 @@ migraph::argument run_gpu()
     auto p = v.create_program();
     p.compile(migraph::gpu::target{});
-    auto m = v.create_params();
-    for(auto&& e : m)
+    migraph::program::parameter_map m;
+    for(auto&& x : p.get_parameter_shapes())
     {
-        e.second = migraph::gpu::to_gpu(e.second);
+        m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
     }
-    m["output"] = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
     return migraph::gpu::from_gpu(p.eval(m));
 }
@@ -45,7 +49,12 @@ void verify_program()
 {
     auto cpu_arg = run_cpu<V>();
     auto gpu_arg = run_gpu<V>();
-    visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { EXPECT(test::verify_range(cpu, gpu)); });
+    visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) {
+        if(not test::verify_range(cpu, gpu))
+        {
+            std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
+        }
+    });
 }
 struct test_literals
@@ -61,8 +70,6 @@ struct test_literals
         p.add_instruction(migraph::activation{"relu"}, conv);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const { return {}; }
 };

 struct test_add
@@ -76,14 +83,6 @@ struct test_add
         p.add_instruction(migraph::add{}, x, y);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["x"] = migraph::generate_argument({migraph::shape::float_type, {3}});
-        m["y"] = migraph::generate_argument({migraph::shape::float_type, {3}});
-        return m;
-    }
 };

 struct test_add_broadcast
@@ -98,14 +97,6 @@ struct test_add_broadcast
         p.add_instruction(migraph::add{}, x, by);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["x"] = migraph::generate_argument({migraph::shape::float_type, {2, 2, 3}});
-        m["y"] = migraph::generate_argument({migraph::shape::float_type, {2, 2}});
-        return m;
-    }
 };

 struct test_conv_relu
@@ -120,14 +111,6 @@ struct test_conv_relu
         p.add_instruction(migraph::activation{"relu"}, conv);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
-        m["w"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
-        return m;
-    }
 };

 struct test_conv_pooling
@@ -144,14 +127,6 @@ struct test_conv_pooling
         p.add_instruction(migraph::activation{"relu"}, pooling);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 32, 32}});
-        m["w"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
-        return m;
-    }
 };
 struct test_gemm
@@ -164,13 +139,57 @@ struct test_gemm
         p.add_instruction(migraph::gemm{}, a, b);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["a"] = migraph::generate_argument({migraph::shape::float_type, {4, 5}});
-        m["b"] = migraph::generate_argument({migraph::shape::float_type, {5, 3}});
-        return m;
-    }
 };

+struct test_gemm_ld
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}, {10, 1}});
+        auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}, {20, 1}});
+        p.add_instruction(migraph::gemm{}, a, b);
+        return p;
+    }
+};
+
+struct test_gemm_transposeb
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto a  = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
+        auto b  = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
+        auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b);
+        p.add_instruction(migraph::gemm{}, a, bt);
+        return p;
+    }
+};
+
+struct test_gemm_transposea
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto a  = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
+        auto b  = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
+        auto at = p.add_instruction(migraph::transpose{{1, 0}}, a);
+        p.add_instruction(migraph::gemm{}, at, b);
+        return p;
+    }
+};
+
+struct test_gemm_transposeab
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto a  = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
+        auto b  = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
+        auto at = p.add_instruction(migraph::transpose{{1, 0}}, a);
+        auto bt = p.add_instruction(migraph::transpose{{1, 0}}, b);
+        p.add_instruction(migraph::gemm{}, at, bt);
+        return p;
+    }
+};
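These transpose variants are what give the strided rocBLAS path above real coverage; for instance, test_gemm_transposeb hands the gemm a view whose strides reveal the transposition:

    // Sketch: what test_gemm_transposeb feeds the gemm.
    // b:  lens {3, 5}, standard strides {5, 1}
    // bt: lens {5, 3}, strides {1, 5} after transpose{{1, 0}}
    // bt.get_shape().transposed() is true, so the GPU lowering can pass
    // rocblas_operation_transpose instead of forcing a contiguous copy.

test_gemm_ld, whose parameters use leading dimensions larger than the row length (strides {10, 1} and {20, 1}), is added here but left commented out in main below.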
@@ -184,14 +203,6 @@ struct test_contiguous
         p.add_instruction(migraph::contiguous{}, x);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["x"] =
-            migraph::generate_argument({migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}});
-        return m;
-    }
 };

 struct test_transpose
@@ -206,13 +217,6 @@ struct test_transpose
         p.add_instruction(migraph::contiguous{}, l);
         return p;
     }
-
-    migraph::program::parameter_map create_params() const
-    {
-        migraph::program::parameter_map m;
-        m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 4, 4}});
-        return m;
-    }
 };

 struct test_batchnorm_inference
@@ -301,6 +305,10 @@ int main()
     verify_program<test_conv_relu>();
     verify_program<test_conv_pooling>();
     verify_program<test_gemm>();
+    // verify_program<test_gemm_ld>();
+    verify_program<test_gemm_transposeb>();
+    verify_program<test_gemm_transposea>();
+    verify_program<test_gemm_transposeab>();
     verify_program<test_contiguous>();
     verify_program<test_transpose>();
     verify_program<test_batchnorm_inference>();
...
#include <migraph/type_name.hpp>
#include "test.hpp"

struct global_class
{
    struct inner_class
    {
    };
};

namespace foo {
struct ns_class
{
    struct inner_class
    {
    };
};
} // namespace foo

int main()
{
    EXPECT(migraph::get_type_name<global_class>() == "global_class");
    EXPECT(migraph::get_type_name<global_class::inner_class>() == "global_class::inner_class");
    EXPECT(migraph::get_type_name<foo::ns_class>() == "foo::ns_class");
    EXPECT(migraph::get_type_name<foo::ns_class::inner_class>() == "foo::ns_class::inner_class");
}