Commit 9b55685c authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Improve contiguous and concat performance (#368)

* Add env to trace nary device functions

* Formatting

* Improve contiguous and concat performance

* Formatting

* Remove unused variable

* Formatting

* Fix gpu tests

* Formatting

* Add more test for transposed concat

* Formatting

* Compute offset and not index

* Compute multi-index once

* Formatting

* Fix transposed inputs

* Formatting

* Use product order for comparisons of hip_array

* Formatting

* Add missing s parameter

* Formatting

* Don't invert permutation

* Fix tidy warnings

* Formatting

* Remove incorrect license

* Use a single integer for stride

* Formatting

* Fix tidy issue
parent 47b05b0c
...@@ -43,6 +43,12 @@ struct argument : raw_data<argument> ...@@ -43,6 +43,12 @@ struct argument : raw_data<argument>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
// Return a new argument that views this argument's data through shape `s`.
// The lambda captures a copy of *this so the underlying buffer (and its
// ownership) stays alive for the lifetime of the reshaped argument.
// NOTE(review): no validation that `s`'s byte size fits within the existing
// buffer — callers must ensure the shapes are compatible. TODO confirm.
argument reshape(const shape& s) const
{
    argument self = *this;
    return {s, [=]() mutable { return self.data(); }};
}
private: private:
shape m_shape; shape m_shape;
}; };
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Permute `dims` so that result[i] = dims[permutation[i]].
// `permutation` must have the same length as `dims` and contain each index
// of `dims` exactly once.
template <class Vector>
inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permutation)
{
    assert(dims.size() == permutation.size());
    Vector result(dims.size());
    std::size_t pos = 0;
    for(auto axis : permutation)
    {
        result[pos] = dims[axis];
        ++pos;
    }
    return result;
}
// Apply the same axis permutation to both the lengths and the strides of
// `s`, producing the shape of the permuted (transposed) view.
inline shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
    auto permuted_lens    = reorder_dims(s.lens(), permutation);
    auto permuted_strides = reorder_dims(s.strides(), permutation);
    return {s.type(), permuted_lens, permuted_strides};
}
// Compute the permutation of indices that would order `data` according to
// comparator `op`, i.e. op(data[result[i]], data[result[i+1]]) holds for
// consecutive entries.
template <class Vector, class Op>
inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<int64_t> indices(data.size());
    std::iota(indices.begin(), indices.end(), 0);
    auto compare_by_data = [&](auto lhs, auto rhs) { return op(data[lhs], data[rhs]); };
    std::sort(indices.begin(), indices.end(), compare_by_data);
    return indices;
}
// Return the inverse of `permutation`: if `permutation` maps axis i to
// permutation[i], the result maps permutation[i] back to i.
// Computed by direct scatter in O(n) instead of the previous
// sort-based O(n log n) formulation; results are identical for any valid
// permutation (each index 0..n-1 appearing exactly once).
inline std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    std::vector<int64_t> result(permutation.size());
    for(std::size_t i = 0; i < permutation.size(); i++)
        result[permutation[i]] = static_cast<int64_t>(i);
    return result;
}
// Return the axis ordering of `s` from largest stride to smallest, i.e. the
// permutation that would bring the shape into standard (row-major) layout.
inline std::vector<int64_t> find_permutation(const shape& s)
{
    const auto& strides = s.strides();
    return sort_permutation(strides, std::greater<>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraphx { namespace migraphx {
...@@ -43,17 +44,6 @@ auto get_transpose_dims(instruction_ref ins) ...@@ -43,17 +44,6 @@ auto get_transpose_dims(instruction_ref ins)
return any_cast<const op::transpose&>(ins->get_operator()).dims; return any_cast<const op::transpose&>(ins->get_operator()).dims;
} }
std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
{
std::vector<int64_t> result(dims.size());
assert(dims.size() == permutation.size());
for(std::size_t i = 0; i < dims.size(); i++)
{
result[i] = dims[permutation[i]];
}
return result;
}
bool is_no_transpose(const std::vector<int64_t>& dims) bool is_no_transpose(const std::vector<int64_t>& dims)
{ {
if(dims.empty()) if(dims.empty())
...@@ -64,25 +54,6 @@ bool is_no_transpose(const std::vector<int64_t>& dims) ...@@ -64,25 +54,6 @@ bool is_no_transpose(const std::vector<int64_t>& dims)
dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end(); dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
} }
template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return result;
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
struct find_reshaper struct find_reshaper
{ {
auto matcher() const auto matcher() const
......
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp> #include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/launch.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -18,16 +17,12 @@ argument concat(hipStream_t stream, ...@@ -18,16 +17,12 @@ argument concat(hipStream_t stream,
for(std::size_t j = 0; j < ninputs; j++) for(std::size_t j = 0; j < ninputs; j++)
{ {
auto&& arg = args[j]; auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements();
auto offset = offsets[j]; auto offset = offsets[j];
shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()}; auto byte_offset = offset * arg.get_shape().type_size();
hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) { auto output_shape = shape{
gs_launch(stream, nelements)([=](auto i) { arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
auto input_idx = input_shape.multi(i); auto output = argument{output_shape, args.back().data() + byte_offset};
auto idx = output.get_shape().index(input_idx); contiguous(stream, std::move(output), arg);
output.data()[idx + offset] = input[input_idx];
});
});
} }
return args.back(); return args.back();
} }
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void contiguous(hipStream_t stream, argument result, argument arg) void contiguous(hipStream_t stream, argument result, argument arg)
{ {
nary_nonstandard(stream, std::move(result), std::move(arg))([](auto x) { return x; }); nary(stream, std::move(result), std::move(arg))([](auto x) { return x; });
} }
} // namespace device } // namespace device
......
...@@ -9,6 +9,33 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -9,6 +9,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Generates, for a compound-assignment token `op` (e.g. +=) and its binary
// counterpart `binary_op` (e.g. +), the elementwise operators of hip_array:
// array op= array, array op= scalar, plus the binary forms
// array binary_op array, array binary_op scalar, and scalar binary_op array.
// NOTE(review): the scalar-on-the-left binary overload forwards to
// (array op scalar), which is only equivalent for commutative operators; for
// the / and % instantiations below, scalar binary_op array would compute the
// swapped-operand result — confirm no caller relies on those forms.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op)                                                    \
    MIGRAPHX_DEVICE_CONSTEXPR hip_array& operator op(const hip_array& x)                           \
    {                                                                                              \
        for(std::size_t i = 0; i < N; i++)                                                         \
            d[i] op x[i];                                                                          \
        return *this;                                                                              \
    }                                                                                              \
    MIGRAPHX_DEVICE_CONSTEXPR hip_array& operator op(const T& x)                                   \
    {                                                                                              \
        for(std::size_t i = 0; i < N; i++)                                                         \
            d[i] op x;                                                                             \
        return *this;                                                                              \
    }                                                                                              \
    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator binary_op(hip_array x, const hip_array& y) \
    {                                                                                              \
        return x op y;                                                                             \
    }                                                                                              \
    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator binary_op(hip_array x, const T& y)         \
    {                                                                                              \
        return x op y;                                                                             \
    }                                                                                              \
    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator binary_op(const T& y, hip_array x)         \
    {                                                                                              \
        return x op y;                                                                             \
    }
template <class T, std::size_t N> template <class T, std::size_t N>
struct hip_array struct hip_array
{ {
...@@ -49,19 +76,79 @@ struct hip_array ...@@ -49,19 +76,79 @@ struct hip_array
return result; return result;
} }
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y) MIGRAPHX_DEVICE_CONSTEXPR T single(std::size_t width = 100) const
{ {
hip_array result; T result = 0;
T a = 1;
for(std::size_t i = 0; i < N; i++) for(std::size_t i = 0; i < N; i++)
result[i] = x[i] * y[i]; {
result += d[N - i - 1] * a;
a *= width;
}
return result; return result;
} }
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator+(const hip_array& x, const hip_array& y) MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
MIGRAPHX_DEVICE_ARRAY_OP(/=, /)
MIGRAPHX_DEVICE_ARRAY_OP(%=, %)
MIGRAPHX_DEVICE_ARRAY_OP(&=, &)
MIGRAPHX_DEVICE_ARRAY_OP(|=, |)
MIGRAPHX_DEVICE_ARRAY_OP(^=, ^)
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator==(const hip_array& x, const hip_array& y)
{
for(std::size_t i = 0; i < N; i++)
{
if(x[i] != y[i])
return false;
}
return true;
}
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator!=(const hip_array& x, const hip_array& y)
{
return !(x == y);
}
// This uses the product order rather than lexical order
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<(const hip_array& x, const hip_array& y)
{ {
hip_array result{};
for(std::size_t i = 0; i < N; i++) for(std::size_t i = 0; i < N; i++)
result[i] = x[i] + y[i]; {
if(not(x[i] < y[i]))
return false;
}
return true;
}
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator>(const hip_array& x, const hip_array& y)
{
return y < x;
}
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<=(const hip_array& x, const hip_array& y)
{
return (x < y) or (x == y);
}
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator>=(const hip_array& x, const hip_array& y)
{
return (y < x) or (x == y);
}
MIGRAPHX_DEVICE_CONSTEXPR hip_array carry(hip_array result) const
{
std::ptrdiff_t rem = 0;
for(std::ptrdiff_t i = result.size() - 1; i >= 0; i--)
{
auto z = result[i] + rem;
rem = z - std::ptrdiff_t(d[i]) + 1;
if(rem > 0)
z -= rem;
else
rem = 0;
result[i] = z;
}
// Add overflows to the back
if(rem > 0)
result.back() += rem;
return result; return result;
} }
}; };
......
#ifndef MIGRAPHX_GUARD_RTGLIB_MULTI_INDEX_HPP
#define MIGRAPHX_GUARD_RTGLIB_MULTI_INDEX_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Per-thread multi-dimensional iteration state: `id` is this thread's
// starting multi-index and `stride` is the number of elements to skip
// between consecutive iterations, so successive threads cover disjoint,
// interleaved index sets.
template <std::size_t N>
struct multi_index
{
    using hip_index = hip_array<std::size_t, N>;
    hip_index id{};
    std::size_t stride = 0;

    // Advance by `stride` flat positions by bumping only the last
    // coordinate; the oversized coordinate is normalized afterwards by
    // hip_array::carry in for_stride.
    MIGRAPHX_DEVICE_CONSTEXPR hip_index add_stride(hip_index i) const
    {
        i.back() += stride;
        return i;
    }

    // Visit every multi-index assigned to this thread within bounds `n`.
    // Note: `i < n` is hip_array's product order (the loop continues only
    // while every coordinate is strictly in range), and n.carry(...)
    // propagates the overflowed last coordinate into the higher dimensions.
    template <class F>
    MIGRAPHX_DEVICE_CONSTEXPR void for_stride(hip_index n, F f) const
    {
        for(hip_index i = id; i < n; i = n.carry(add_stride(i)))
        {
            f(i);
        }
    }
};
// Build a multi_index starting at the multi-index of flat offset `i` within
// shape `s`, advancing by `n` elements per iteration.
template <std::size_t N>
MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
make_multi_index(const hip_shape<N>& s, std::size_t i, std::size_t n)
{
    return {s.multi(i), n};
}
// Overload taking the stride as a hip_array.
// NOTE(review): multi_index::stride is a single std::size_t, so brace-
// initializing it from a hip_array does not look well-formed unless
// hip_array converts to std::size_t — verify this overload is still used
// (or still compiles) after the "use a single integer for stride" change.
template <std::size_t N>
MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
make_multi_index(const hip_shape<N>& s, std::size_t i, const hip_array<std::size_t, N>& n)
{
    return {s.multi(i), n};
}
// Launch helper: returns a callable that, given a device functor `f`,
// launches it over every multi-index of shape `s`, each thread iterating
// over indices spaced `nglobal` elements apart via multi_index::for_stride.
template <std::size_t N>
inline auto mi_launch(hipStream_t stream, const hip_shape<N>& s, std::size_t local = 1024)
{
    // The multi-index arithmetic assumes a standard (dense) shape.
    assert(s.standard);
    std::size_t n      = s.elements();
    std::size_t groups = (n + local - 1) / local;
    // Cap the grid at 128 workgroups of `local` threads; when n exceeds
    // nglobal, each thread loops over multiple indices.
    std::size_t nglobal = std::min<std::size_t>(128, groups) * local;
    return [=](auto f) {
        launch(stream, nglobal, local)([=](auto idx) {
            auto midx = make_multi_index(s, idx.global, nglobal);
            midx.for_stride(s.lens, [&](auto i) { f(i); });
        });
    };
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -2,17 +2,28 @@ ...@@ -2,17 +2,28 @@
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP #define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
#include <migraphx/gpu/device/visit.hpp> #include <migraphx/gpu/device/visit.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/env.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <iostream>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
// NOLINTNEXTLINE
#define MIGRAPHX_TRACE_NARY_FUNCTION \
if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
template <class... Ts> template <class... Ts>
auto pack(Ts... xs) __device__ auto pack(Ts... xs) __device__
{ {
...@@ -20,14 +31,28 @@ auto pack(Ts... xs) __device__ ...@@ -20,14 +31,28 @@ auto pack(Ts... xs) __device__
} }
template <class F, class... Arguments> template <class F, class... Arguments>
auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args) auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, Arguments... args)
{ {
std::size_t nelements = result.get_shape().elements(); MIGRAPHX_TRACE_NARY_FUNCTION
hip_visit_all(result, args...)([&](auto output, auto... inputs) { shape s{result.get_shape().type(), result.get_shape().lens()};
gs_launch(stream, nelements)([=](auto i) { hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) {
auto idx = output.get_shape().multi(i); mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
output[i] = f(inputs[idx]...);
}); });
}
template <class F, class... Arguments>
auto nary_nonstandard_packed_impl(hipStream_t stream,
F f,
const argument& result,
Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
auto arg_shape = make_array(args...).front().get_shape();
auto perm = find_permutation(arg_shape);
auto s = reorder_shape(arg_shape, perm);
hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)(
[&](auto standard_shape, auto output, auto... inputs) {
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
}); });
} }
...@@ -35,6 +60,7 @@ template <class F, class... Arguments> ...@@ -35,6 +60,7 @@ template <class F, class... Arguments>
void nary_broadcast_vec_impl( void nary_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args) hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{ {
MIGRAPHX_TRACE_NARY_FUNCTION
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape(); const auto& b_shape = barg.get_shape();
auto bdim = auto bdim =
...@@ -83,6 +109,7 @@ void nary_broadcast_vec_impl( ...@@ -83,6 +109,7 @@ void nary_broadcast_vec_impl(
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args) void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{ {
MIGRAPHX_TRACE_NARY_FUNCTION
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape(); const auto& b_shape = barg.get_shape();
auto bdim = auto bdim =
...@@ -122,6 +149,7 @@ template <class F, class... Arguments> ...@@ -122,6 +149,7 @@ template <class F, class... Arguments>
void nary_double_broadcast_vec_impl( void nary_double_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args) hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{ {
MIGRAPHX_TRACE_NARY_FUNCTION
assert(barg1.get_shape().broadcasted()); assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted()); assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape()); assert(barg1.get_shape() == barg2.get_shape());
...@@ -179,6 +207,7 @@ template <class F, class... Arguments> ...@@ -179,6 +207,7 @@ template <class F, class... Arguments>
void nary_double_broadcast_impl( void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args) hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{ {
MIGRAPHX_TRACE_NARY_FUNCTION
assert(barg1.get_shape().broadcasted()); assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted()); assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape()); assert(barg1.get_shape() == barg2.get_shape());
...@@ -226,6 +255,7 @@ void nary_double_broadcast_impl( ...@@ -226,6 +255,7 @@ void nary_double_broadcast_impl(
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args) void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{ {
MIGRAPHX_TRACE_NARY_FUNCTION
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
...@@ -250,6 +280,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments. ...@@ -250,6 +280,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args) void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
{ {
MIGRAPHX_TRACE_NARY_FUNCTION
std::size_t nelements = result.get_shape().elements(); std::size_t nelements = result.get_shape().elements();
hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) { hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); }); gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
...@@ -259,20 +290,25 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a ...@@ -259,20 +290,25 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a
template <class F, class... Arguments> template <class F, class... Arguments>
void nary_impl(hipStream_t stream, F f, argument result, Arguments... args) void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
{ {
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); MIGRAPHX_TRACE_NARY_FUNCTION
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); }); const auto shapes = make_array(args.get_shape()...);
bool same_shapes = const bool standard = all_of(shapes, [](const shape& s) { return s.standard(); });
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); const bool packed = all_of(shapes, [](const shape& s) { return s.packed(); });
if(standard or (packed and same_shapes)) const bool same_shapes =
all_of(shapes, [&](const shape& s) { return s == result.get_shape(); });
const bool same_input_shapes = all_of(shapes, [&](const shape& s) { return s == shapes[0]; });
if((result.get_shape().standard() and standard) or (packed and same_shapes))
nary_standard_impl(stream, f, result, args...); nary_standard_impl(stream, f, result, args...);
else if(packed and same_input_shapes)
nary_nonstandard_packed_impl(stream, f, result, args...);
else else
nary_nonstandard_impl(stream, f, result, args...); nary_nonstandard_nonpacked_impl(stream, f, result, args...);
} }
template <class... Arguments> template <class... Arguments>
auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args) auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
{ {
return [=](auto f) { nary_nonstandard_impl(stream, f, result, args...); }; return [=](auto f) { nary_nonstandard_nonpacked_impl(stream, f, result, args...); };
} }
template <class... Arguments> template <class... Arguments>
......
...@@ -73,22 +73,6 @@ struct hip_shape ...@@ -73,22 +73,6 @@ struct hip_shape
} }
return result; return result;
} }
MIGRAPHX_DEVICE_CONSTEXPR hip_index carry(hip_index result) const
{
std::ptrdiff_t rem = 0;
for(std::ptrdiff_t i = result.size() - 1; i >= 0; i--)
{
auto z = result[i] + rem;
rem = z - std::ptrdiff_t(lens[i]) + 1;
if(rem > 0)
z -= rem;
else
rem = 0;
result[i] = z;
}
return result;
}
}; };
template <std::size_t N> template <std::size_t N>
......
...@@ -1830,6 +1830,61 @@ struct test_concat2 : verify_program<test_concat2> ...@@ -1830,6 +1830,61 @@ struct test_concat2 : verify_program<test_concat2>
} }
}; };
// Verify concat along axis 1 when the middle input is produced by a
// transpose (non-standard layout).
struct test_concat_transpose : verify_program<test_concat_transpose>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        std::size_t concat_axis = 1;
        migraphx::shape x_shape{migraphx::shape::int32_type, {2, 2}};
        migraphx::shape y_shape{migraphx::shape::int32_type, {3, 2}};
        migraphx::shape z_shape{migraphx::shape::int32_type, {2, 4}};
        auto x_arg     = p.add_parameter("x", x_shape);
        auto y_arg     = p.add_parameter("y", y_shape);
        auto y_trans   = p.add_instruction(migraphx::op::transpose{{1, 0}}, y_arg);
        auto z_arg     = p.add_parameter("z", z_shape);
        p.add_instruction(migraphx::op::concat{concat_axis}, x_arg, y_trans, z_arg);
        return p;
    }
};
// Verify concat along axis 1 when the last input is produced by a
// transpose (non-standard layout).
struct test_concat_transpose2 : verify_program<test_concat_transpose2>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        std::size_t concat_axis = 1;
        migraphx::shape x_shape{migraphx::shape::int32_type, {2, 2}};
        migraphx::shape y_shape{migraphx::shape::int32_type, {2, 3}};
        migraphx::shape z_shape{migraphx::shape::int32_type, {5, 2}};
        auto x_arg   = p.add_parameter("x", x_shape);
        auto y_arg   = p.add_parameter("y", y_shape);
        auto z_arg   = p.add_parameter("z", z_shape);
        auto z_trans = p.add_instruction(migraphx::op::transpose{{1, 0}}, z_arg);
        p.add_instruction(migraphx::op::concat{concat_axis}, x_arg, y_arg, z_trans);
        return p;
    }
};
// Verify concat along axis 1 when two of the three inputs are produced by
// transposes (non-standard layouts).
struct test_concat_transpose3 : verify_program<test_concat_transpose3>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        std::size_t concat_axis = 1;
        migraphx::shape x_shape{migraphx::shape::int32_type, {2, 2}};
        migraphx::shape y_shape{migraphx::shape::int32_type, {3, 2}};
        migraphx::shape z_shape{migraphx::shape::int32_type, {5, 2}};
        auto x_arg   = p.add_parameter("x", x_shape);
        auto y_arg   = p.add_parameter("y", y_shape);
        auto y_trans = p.add_instruction(migraphx::op::transpose{{1, 0}}, y_arg);
        auto z_arg   = p.add_parameter("z", z_shape);
        auto z_trans = p.add_instruction(migraphx::op::transpose{{1, 0}}, z_arg);
        p.add_instruction(migraphx::op::concat{concat_axis}, x_arg, y_trans, z_trans);
        return p;
    }
};
struct test_concat_relu : verify_program<test_concat_relu> struct test_concat_relu : verify_program<test_concat_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment