Commit 2f268bc2 authored by Paul

Merge branch 'develop' into mlir-c

parents f75c5a38 aa7ff911
@@ -83,11 +83,12 @@ struct shape
        }
    }
+    /// Convert single index into a multi-index
    constexpr index_array multi(index_int idx) const
    {
        index_array result;
        index_int tidx = idx;
-        for(std::ptrdiff_t is = result.size() - 1; is > 0; is--)
+        for(diff_int is = result.size() - 1; is > 0; is--)
        {
            result[is] = tidx % lens[is];
            tidx = tidx / lens[is];
@@ -95,6 +96,13 @@ struct shape
        result[0] = tidx;
        return result;
    }
+    /// Convert multi-index into a single index
+    constexpr index_int single(index_array idx) const
+    {
+        if(idx.empty())
+            return 0;
+        return inner_product(lens.begin() + 1, lens.end(), idx.begin(), idx.back());
+    }
    constexpr shape get_shape() const { return *this; }
...
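For context on the two helpers above: multi() peels a flat element index into per-dimension coordinates by repeated division and modulo over lens, and single() folds a multi-index back into a flat index. A standalone sketch of the same arithmetic for a 2-D shape (plain std::vector instead of the kernel's index_array; illustrative only, not the kernel header itself):

#include <cassert>
#include <cstddef>
#include <vector>

// Decompose a flat row-major index into per-dimension coordinates (mirrors multi()).
std::vector<std::size_t> to_multi(std::size_t idx, const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> result(lens.size());
    for(std::ptrdiff_t i = lens.size() - 1; i > 0; i--)
    {
        result[i] = idx % lens[i];
        idx       = idx / lens[i];
    }
    result[0] = idx;
    return result;
}

// Fold 2-D coordinates back into a flat index: i0 * lens[1] + i1.
std::size_t to_single_2d(const std::vector<std::size_t>& idx, const std::vector<std::size_t>& lens)
{
    return idx[0] * lens[1] + idx[1];
}

int main()
{
    std::vector<std::size_t> lens{3, 4};
    assert((to_multi(7, lens) == std::vector<std::size_t>{1, 3})); // 7 = 1 * 4 + 3
    assert(to_single_2d({1, 3}, lens) == 7);                       // round trip
    return 0;
}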
@@ -11,7 +11,7 @@ template <class T>
struct tensor_view_iterator_read
{
    T* view;
-    constexpr auto& operator()(std::size_t n) const
+    constexpr auto& operator()(index_int n) const
    {
        MIGRAPHX_ASSERT(view != nullptr);
        return (*view)[n];
@@ -21,18 +21,31 @@ struct tensor_view_iterator_read
template <class T, class Shape>
struct tensor_view
{
    using type = T;
    using shape_type = Shape;
+    using index_array = typename Shape::index_array;
    using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
    constexpr Shape get_shape() const { return Shape{}; }
    constexpr auto size() const { return get_shape().elements(); }
-    template <class U>
-    constexpr T& operator[](U i) const
-    {
-        MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
-        return x[get_shape().index(i)];
-    }
+    struct index_to_offset
+    {
+        index_int offset;
+        template <class U>
+        constexpr index_to_offset(U i) : offset(Shape{}.index(i))
+        {
+        }
+    };
+    constexpr T& operator[](MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const
+    {
+        index_to_offset ito = i;
+        MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
+                      i,
+                      "Out of bounds access at offset: ",
+                      ito.offset);
+        return x[ito.offset];
+    }
    constexpr T* data() const { return x; }
@@ -40,6 +53,13 @@ struct tensor_view
    constexpr auto begin() const { return iterator{0, {this}}; }
    constexpr auto end() const { return iterator{this->size(), {this}}; }
+    constexpr auto begin_at(index_array i) const
+    {
+        MIGRAPHX_ASSERT(get_shape().single(i) < get_shape().elements());
+        MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
+        return iterator{get_shape().single(i), {this}};
+    }
    template <class U>
    constexpr tensor_view<U, Shape> with(U* y) const
    {
@@ -50,6 +70,9 @@ struct tensor_view
    T* x;
};
+template <class T>
+using get_shape_c = typename T::shape_type;
template <class T, class Shape>
constexpr tensor_view<T, Shape> make_tensor_view(T* x, Shape)
{
...
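The new index_to_offset wrapper converts whatever index form the caller passes (flat or multi) into a raw offset once, so the out-of-bounds warning and the element access use the same value, and begin_at() starts iteration at a multi-index by mapping it to a flat element position. A standalone sketch of the bounds-checked access idea (a hypothetical checked_view type, not the library's tensor_view, which additionally maps indices through the shape's strides):

#include <cassert>
#include <cstdio>

// Minimal bounds-checked view over a raw pointer.
template <class T>
struct checked_view
{
    T* data;
    unsigned long n;

    T& operator[](unsigned long offset) const
    {
        // Warn (here: print) and keep going, mirroring MIGRAPHX_WARN in the diff above.
        if(offset >= n)
            std::printf("Out of bounds access at offset: %lu\n", offset);
        return data[offset];
    }
};

int main()
{
    int buf[4] = {1, 2, 3, 4};
    checked_view<int> v{buf, 4};
    assert(v[2] == 3); // in-bounds access, no warning printed
    return 0;
}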
@@ -6,6 +6,15 @@
namespace migraphx {
+template <class T, class U = T&&>
+U private_declval(int);
+template <class T>
+T private_declval(long);
+template <class T>
+auto declval() noexcept -> decltype(private_declval<T>(0));
template <class T>
struct type_identity
{
@@ -26,20 +35,88 @@ struct enable_if<true, T>
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
-template <class From, class To>
-struct is_convertible : bool_constant<__is_convertible(From, To)>
-{
-};
-template <class T, class U>
-struct is_same : false_type
-{
-};
-template <class T>
-struct is_same<T, T> : true_type
-{
-};
+template <bool B, class T, class F>
+struct conditional
+{
+    using type = T;
+};
+template <class T, class F>
+struct conditional<false, T, F>
+{
+    using type = F;
+};
+template <bool B, class T, class F>
+using conditional_t = typename conditional<B, T, F>::type;
+// NOLINTNEXTLINE
+#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name)        \
+    template <class T>                            \
+    struct name : bool_constant<__##name(T)>      \
+    {                                             \
+    }
+// NOLINTNEXTLINE
+#define MIGRAPHX_BUILTIN_TYPE_TRAIT2(name)        \
+    template <class T, class U>                   \
+    struct name : bool_constant<__##name(T, U)>   \
+    {                                             \
+    }
+// NOLINTNEXTLINE
+#define MIGRAPHX_BUILTIN_TYPE_TRAITN(name)        \
+    template <class... Ts>                        \
+    struct name : bool_constant<__##name(Ts...)>  \
+    {                                             \
+    }
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_arithmetic);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_destructible);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pointer);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_scalar);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_signed);
+// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_void);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_abstract);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_aggregate);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_array);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_class);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_compound);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_const);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_empty);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_enum);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_final);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_floating_point);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_function);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_fundamental);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_integral);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_literal_type);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_lvalue_reference);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_function_pointer);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_object_pointer);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_pointer);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_object);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pod);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_polymorphic);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_reference);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_rvalue_reference);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_standard_layout);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivial);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_destructible);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_union);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_unsigned);
+MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_volatile);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_assignable);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_base_of);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_convertible);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_same);
+MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_trivially_assignable);
+MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
+MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
+MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template <class T>
struct remove_reference
@@ -68,6 +145,68 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
+template <class... Ts>
+struct common_type;
+template <class T>
+struct common_type<T>
+{
+    using type = T;
+};
+template <class T, class U>
+struct common_type<T, U>
+{
+    using type = decltype(true ? declval<T>() : declval<U>());
+};
+template <class T, class U, class... Us>
+struct common_type<T, U, Us...>
+{
+    using type = typename common_type<typename common_type<T, U>::type, Us...>::type;
+};
+template <class... Ts>
+using common_type_t = typename common_type<Ts...>::type;
+constexpr unsigned long int_max(unsigned long n) { return (1ul << (n * 8 - 1)) - 1; }
+template <class T>
+constexpr T numeric_max()
+{
+    if constexpr(is_integral<T>{})
+    {
+        if constexpr(is_unsigned<T>{})
+            return int_max(sizeof(T)) * 2 + 1;
+        else
+            return int_max(sizeof(T));
+    }
+    else if constexpr(is_same<T, double>{})
+        return __DBL_MAX__;
+    else if constexpr(is_same<T, float>{})
+        return __FLT_MAX__;
+    else if constexpr(is_same<T, migraphx::half>{})
+        return __FLT16_MAX__;
+    else
+        return 0;
+}
+template <class T>
+constexpr T numeric_lowest()
+{
+    if constexpr(is_integral<T>{})
+    {
+        if constexpr(is_unsigned<T>{})
+            return 0;
+        else
+            return -numeric_max<T>() - 1;
+    }
+    else
+    {
+        return -numeric_max<T>();
+    }
+}
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx
...
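The integer limits above are built directly from the byte count: the signed maximum of an n-byte integer is 2^(8n-1) - 1 and the unsigned maximum is twice that plus one. A standalone cross-check of that arithmetic, assuming the signed-max form used above and using <limits> only for verification:

#include <cstdint>
#include <limits>

// Signed max for an n-byte integer: 2^(8n - 1) - 1; unsigned max is 2 * signed_max + 1.
constexpr unsigned long int_max(unsigned long n) { return (1ul << (n * 8 - 1)) - 1; }

static_assert(int_max(sizeof(std::int8_t)) == std::numeric_limits<std::int8_t>::max(), "");
static_assert(int_max(sizeof(std::int32_t)) == std::numeric_limits<std::int32_t>::max(), "");
static_assert(int_max(sizeof(std::uint16_t)) * 2 + 1 == std::numeric_limits<std::uint16_t>::max(),
              "");

int main() { return 0; }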
@@ -13,7 +13,8 @@ using diff_int = std::int32_t;
template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N)));
using half = _Float16;
+using half2 = migraphx::vec<half, 2>;
} // namespace migraphx
...
@@ -46,6 +46,9 @@ constexpr auto vec_at(T x, I i)
    }
}
+template <class T>
+using vec_type = decltype(vec_at(T{}, 0));
template <class... Ts>
constexpr auto common_vec_size()
{
@@ -57,17 +60,26 @@ constexpr auto common_vec_size()
    })(vec_size<Ts>()...);
}
+// Bools can not be used as a vector type so convert it to uint8
+template <class T>
+__device__ __host__ T* remove_bool(T* x)
+{
+    return x;
+}
+inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T>
__device__ __host__ auto as_vec(T* x)
{
-    if constexpr(N == 0)
+    if constexpr(N < 2)
        return x;
    else
        return reinterpret_cast<vec<T, N>*>(x);
}
template <class T, index_int N>
-using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
+using safe_vec = vec<conditional_t<is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
@@ -89,5 +101,64 @@ constexpr auto vec_transform(Ts... xs)
    };
}
+// Return a vector type of N from index i in another larger vector
+// N will be 2 for half2 packing
+template <index_int N, class T, class I>
+constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
+{
+    if constexpr(vec_size<T>() == 0)
+        return vec<T, N>{x};
+    else
+    {
+        MIGRAPHX_ASSERT((i + N) <= vec_size<T>());
+        vec<vec_type<T>, N> result = {0};
+        for(int j = 0; j < N; j++)
+        {
+            result[j] = x[i + j];
+        }
+        return result;
+    }
+}
+template <index_int N, class... Ts>
+constexpr auto vec_packed_transform(Ts... xs)
+{
+    return [=](auto f) {
+        if constexpr(is_any_vec<Ts...>())
+        {
+            using type = vec_type<decltype(f(vec_packed_at<N>(xs, 0)...))>;
+            constexpr auto size = common_vec_size<Ts...>();
+            safe_vec<type, size> result = {0};
+            for(int i = 0; i < size / N; i++)
+            {
+                // Call the function with packed vectors
+                safe_vec<type, N> r = f(vec_packed_at<N>(xs, i * N)...);
+                // Copy the packed vectors to the result
+                for(int j = 0; j < N; j++)
+                    result[i * N + j] = r[j];
+            }
+            return result;
+        }
+        else
+        {
+            return f(xs...);
+        }
+    };
+}
+template <class T, class Op>
+constexpr auto vec_reduce(T x, Op op)
+{
+    if constexpr(vec_size<T>() < 2)
+        return x;
+    else
+    {
+        vec_type<T> result = x[0];
+        for(int i = 1; i < vec_size<T>(); i++)
+            result = op(result, x[i]);
+        return result;
+    }
+}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
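The helpers added above walk a wide vector in fixed-size packs (packs of 2 for half2 math) and fold a vector back to a scalar. A standalone sketch of the reduction idea using the same clang ext_vector_type extension as the types header (illustrative only; it does not use the kernel headers):

// Compile with clang; ext_vector_type is the same extension used by migraphx::vec.
#include <cassert>

template <class T, int N>
using vec = T __attribute__((ext_vector_type(N)));

// Fold all lanes of a vector into one scalar with a binary op (mirrors vec_reduce).
template <class T, int N, class Op>
T reduce(vec<T, N> x, Op op)
{
    T result = x[0];
    for(int i = 1; i < N; i++)
        result = op(result, x[i]);
    return result;
}

int main()
{
    vec<float, 4> v = {1.0f, 2.0f, 3.0f, 4.0f};
    assert(reduce<float, 4>(v, [](float a, float b) { return a + b; }) == 10.0f);
    return 0;
}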
@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
    });
}
-// Bools can not be used as a vector type so convert it to uint8
-template <class T>
-__device__ __host__ T* remove_bool(T* x)
-{
-    return x;
-}
-inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
{
-    if constexpr(N == 0)
+    if constexpr(N < 2)
        return x;
    else
        return make_tensor_view(as_vec<N>(remove_bool(x.data())),
@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis axis)
{
-    if constexpr(N == 0)
+    if constexpr(N < 2)
    {
        return x;
    }
@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
    else if constexpr(decltype(pred(_c<2>)){})
        return _c<2>;
    else
-        return _c<0>;
+        return _c<1>;
}
template <class T>
-__host__ __device__ auto vectorize(T x)
+__host__ __device__ auto auto_vectorize(T x)
{
    if constexpr(tensor_vec_size<T>() == 0)
    {
@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
    {
        MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
        MIGRAPHX_ASSERT(s.lens[axis] > 0);
-        MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
+        MIGRAPHX_ASSERT(n == 1 or s.lens[axis] % n == 0);
        if constexpr(s.strides[axis] == 0)
            return tensor_step<n>(x, axis);
        else
@@ -215,7 +206,34 @@ inline __device__ __host__ auto auto_vectorize()
inline __device__ __host__ auto auto_vectorize()
{
-    return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
+    return make_transform([](auto f, auto... xs) { auto_vectorize_impl(f, xs...); });
+}
+template <index_int N, index_int Axis, class T>
+__device__ __host__ auto vectorize_tensor(T x)
+{
+    constexpr auto shape = get_shape_c<T>{};
+    if constexpr(shape.lens[Axis] == 1)
+        return x;
+    else if constexpr(shape.strides[Axis] == 0)
+        return tensor_step<N>(x, _c<Axis>);
+    else
+        return as_vec<N>(x, _c<Axis>);
+}
+template <index_int N, index_int Axis>
+__device__ __host__ auto vectorize()
+{
+    return make_transform([](auto f, auto... xs) {
+        if constexpr(N < 2)
+        {
+            f(xs...);
+        }
+        else
+        {
+            f(vectorize_tensor<N, Axis>(xs)...);
+        }
+    });
}
} // namespace migraphx
...
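The vectorize pass rewrites a tensor of scalars into a tensor of short vectors along one axis; that is only valid when the axis is contiguous (stride 0 or 1) and its length divides evenly by the vector width, which is exactly what the asserts check. A standalone sketch of the underlying reinterpretation (clang vector extension, illustrative only, not the kernel code):

#include <cassert>

template <class T, int N>
using vec = T __attribute__((ext_vector_type(N)));

// Reinterpret a contiguous array of scalars as an array of vec<T, N>.
// Precondition: the element count is a multiple of N, matching the assert above.
template <int N, class T>
vec<T, N>* as_vec(T* x)
{
    return reinterpret_cast<vec<T, N>*>(x);
}

int main()
{
    alignas(16) float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    auto* v                   = as_vec<4>(data); // 8 % 4 == 0, so the view is valid
    assert(v[1][2] == 6.0f);                     // second pack, third lane
    return 0;
}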
@@ -181,16 +181,11 @@ struct miopen_apply
        add_extend_op("pad");
        add_extend_op("pooling");
        add_extend_op("prefix_scan_sum");
-        add_extend_op("reduce_max");
-        add_extend_op("reduce_mean");
-        add_extend_op("reduce_min");
-        add_extend_op("reduce_prod");
-        add_extend_op("reduce_sum");
        add_extend_op("reverse");
        add_extend_op("rnn_var_sl_last_output");
        add_extend_op("rnn_var_sl_shift_output");
        add_extend_op("rnn_var_sl_shift_sequence");
-        add_extend_op("scatter");
+        add_extend_op("scatter_none");
        add_extend_op("softmax");
        add_extend_op("topk");
@@ -370,8 +365,22 @@ struct miopen_apply
    {
        apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
            auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
-            auto conv = miopen_quant_convolution{op, make_conv(op)};
-            auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
+            shape ws;
+            miopen_quant_convolution conv;
+            auto compile_quant_conv_with_format = [&](bool format) {
+                conv = miopen_quant_convolution{op, format, make_conv(op)};
+                ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
+            };
+            try
+            {
+                compile_quant_conv_with_format(int8_x4_format);
+            }
+            catch(migraphx::exception&)
+            {
+                // In case no solver supports the default format, retry using the other format.
+                compile_quant_conv_with_format(!int8_x4_format);
+            }
            auto args = ins->inputs();
            auto workspace = insert_allocation(ins, ws, "workspace");
@@ -381,6 +390,9 @@ struct miopen_apply
        });
    }
+    // add_generic_op just constructs the operator with no fields, whereas add_extend_op copies
+    // over the fields. Since it doesn't have fields, it is default constructed.
    void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
    void add_generic_op(const std::string& op_name, const std::string& gpu_name)
...
@@ -22,10 +22,10 @@ static instruction_ref pad_ins(module& m, instruction_ref ins, int offset)
    auto pad_k = (k + 3) / 4 * 4;
    auto pad_lens = lens;
    pad_lens[lens.size() + offset] = pad_k;
-    std::vector<int64_t> pad_dims(lens.size() * 2, 0);
    auto ret_ins = ins;
    if(pad_k != k)
    {
+        std::vector<int64_t> pad_dims(lens.size() * 2, 0);
        pad_dims[lens.size() + offset] = pad_k - k;
        shape ps{s.type(), pad_lens};
        auto ins_out =
@@ -118,7 +118,7 @@ void pack_int8_args::apply(module& m) const
        assert(val.contains("int8_x4_format"));
        if(not val.at("int8_x4_format").to<bool>())
        {
-            return;
+            continue;
        }
        auto inputs = ins->inputs();
        auto lens = inputs.at(0)->get_shape().lens();
@@ -156,6 +156,12 @@ void pack_int8_args::apply(module& m) const
        }
        else if(ins->name() == "gpu::quant_convolution")
        {
+            auto val = ins->get_operator().to_value();
+            if(not val.at("int8_x4_format").to<bool>())
+            {
+                continue;
+            }
            auto inputs = ins->inputs();
            auto packed_x = m.insert_instruction(
                ins,
...
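The int8x4 packed layout requires the packed dimension to be a multiple of 4, and the expression (k + 3) / 4 * 4 rounds k up to the next multiple of 4 so the pass only inserts a pad when that rounded value differs from k. A small standalone check of that rounding:

#include <cstddef>

// Round k up to the next multiple of 4, as done for the int8x4 packed layout above.
constexpr std::size_t round_up4(std::size_t k) { return (k + 3) / 4 * 4; }

static_assert(round_up4(4) == 4, "already aligned: no padding inserted");
static_assert(round_up4(5) == 8, "pad 3 channels");
static_assert(round_up4(30) == 32, "pad 2 channels");

int main() { return 0; }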
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
struct find_layernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
if(not x_ins->get_shape().standard())
x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
}
};
struct find_triaddlayernorm
{
auto matcher() const
{
auto add1 =
match::name("add")(match::none_of(match::is_constant()),
match::args(match::any().bind("z1"), match::any().bind("z2")));
auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
return match::layernorm()(match::var("x")(add2));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["z1"];
auto y_ins = r.instructions["z2"];
auto z_ins = r.instructions["z3"];
for(auto* pins : {&x_ins, &y_ins, &z_ins})
{
if(not(*pins)->get_shape().standard())
*pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
}
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
}
};
} // namespace
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
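Both matchers above bail out unless the reduction length fits the fused layernorm kernel: at most 1024 elements, and lengths above 256 must additionally be divisible by 4 so the kernel can vectorize. A small predicate capturing that condition (the helper name is hypothetical; it only mirrors the check in the matchers):

#include <cstddef>

// True when the fused layernorm above would be used for a reduction of `relements`
// elements; mirrors `if(relements > 1024 or (relements % 4 != 0 and relements > 256)) return;`.
constexpr bool fits_fused_layernorm(std::size_t relements)
{
    if(relements > 1024)
        return false;
    if(relements % 4 != 0 and relements > 256)
        return false;
    return true;
}

static_assert(fits_fused_layernorm(256), "small sizes are always accepted");
static_assert(fits_fused_layernorm(768), "larger sizes must be a multiple of 4");
static_assert(not fits_fused_layernorm(770), "770 is not a multiple of 4");
static_assert(not fits_fused_layernorm(2048), "too large for the fused kernel");

int main() { return 0; }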
@@ -16,8 +16,8 @@ argument miopen_quant_convolution::compute(context& ctx,
                                           const shape& output_shape,
                                           const std::vector<argument>& args) const
{
-    auto x_desc = make_tensor(args[0].get_shape(), true);
-    auto w_desc = make_tensor(args[1].get_shape(), true);
+    auto x_desc = make_tensor(args[0].get_shape(), int8_x4_format);
+    auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
    auto y_desc = make_tensor(output_shape);
    float alpha = 1;
@@ -49,8 +49,8 @@ shape miopen_quant_convolution::compile(context& ctx,
                                        std::vector<shape> inputs)
{
    shape workspace_shape{};
-    auto x_desc = make_tensor(inputs[0], true);
-    auto w_desc = make_tensor(inputs[1], true);
+    auto x_desc = make_tensor(inputs[0], int8_x4_format);
+    auto w_desc = make_tensor(inputs[1], int8_x4_format);
    auto y_desc = make_tensor(output_shape);
    std::size_t workspace_size = 0;
@@ -62,8 +62,15 @@ shape miopen_quant_convolution::compile(context& ctx,
                                              &workspace_size);
    workspace_shape = shape{shape::int8_type, {workspace_size}};
-    auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
-    auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
+    auto x_shape = inputs[0];
+    auto w_shape = inputs[1];
+    if(int8_x4_format)
+    {
+        x_shape = pack_int8_shape(x_shape);
+        w_shape = pack_int8_shape(w_shape);
+    }
+    auto arg_vec4_x = to_gpu(generate_argument(x_shape));
+    auto arg_vec4_w = to_gpu(generate_argument(w_shape));
    auto y = allocate_gpu(output_shape);
    auto workspace = allocate_gpu(workspace_shape);
...
@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
MIGRAPHX_REGISTER_OP(set_stream)
std::size_t schedule_model::concurrency() const { return streams; }
-void schedule_model::sched(module& p, instruction_ref ins, std::size_t n) const
+void schedule_model::sched(module& m, instruction_ref ins, std::size_t n) const
{
    auto last_stream = std::find_if(std::make_reverse_iterator(ins),
-                                    std::make_reverse_iterator(p.begin()),
+                                    std::make_reverse_iterator(m.begin()),
                                    [&](auto&& i) { return i.name() == "gpu::set_stream"; });
-    if(last_stream != std::make_reverse_iterator(p.begin()))
+    if(last_stream != std::make_reverse_iterator(m.begin()))
    {
        auto&& op = any_cast<set_stream>(last_stream->get_operator());
        // If the same stream was set earlier then skip
        if(op.stream == n)
            return;
    }
-    p.insert_instruction(ins, set_stream{n});
+    m.insert_instruction(ins, set_stream{n});
}
-void schedule_model::wait(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::wait(module& m, instruction_ref ins, std::size_t wait_id) const
{
-    p.insert_instruction(ins, wait_event{wait_id});
+    m.insert_instruction(ins, wait_event{wait_id});
}
-void schedule_model::record(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::record(module& m, instruction_ref ins, std::size_t wait_id) const
{
-    p.insert_instruction(std::next(ins), record_event{wait_id});
+    m.insert_instruction(std::next(ins), record_event{wait_id});
}
static std::unordered_map<std::string, std::size_t> create_weight_map()
...
@@ -8,9 +8,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
-void sync_device::apply(module& p) const
+void sync_device::apply(module& m) const
{
-    auto last = std::prev(p.end());
+    auto last = std::prev(m.end());
    if(last->name() == "@return")
    {
        auto inputs = last->inputs();
@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const
               return (i->name() == "hip::copy_from_gpu");
           }))
        {
-            auto sync_in = p.insert_instruction(last, make_op("hip::sync_stream"), inputs);
+            auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs);
            if(not inputs.empty())
            {
-                p.replace_instruction(inputs.front(), sync_in);
+                m.replace_instruction(inputs.front(), sync_in);
            }
        }
    }
...
@@ -32,6 +32,7 @@
#include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
+#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/schedule_model.hpp>
@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        simplify_algebra{},
        simplify_reshapes{},
        simplify_algebra{},
+        prefuse_ops{},
+        dead_code_elimination{},
        auto_contiguous{},
        simplify_reshapes{},
        propagate_constant{},
...
@@ -11,25 +11,25 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
-void write_literals::apply(module& p) const
+void write_literals::apply(module& m) const
{
    assert(ctx != nullptr);
    std::size_t n = 0;
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "@literal")
        {
            if(enabled(MIGRAPHX_COPY_LITERALS{}))
            {
                literal l = ins->get_literal();
-                auto pre = p.add_literal(l);
-                auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
-                p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
+                auto pre = m.add_literal(l);
+                auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
+                m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
            }
            else
            {
-                std::string id = p.name() + ":@literal:" + std::to_string(n);
-                p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
+                std::string id = m.name() + ":@literal:" + std::to_string(n);
+                m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
                n++;
            }
        }
...
#include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
-#include <migraphx/shape_for_each.hpp>
+#include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h>
namespace migraphx {
@@ -74,8 +74,10 @@ void migemm_impl(
    assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
+    auto cs = cmat.get_shape();
-    shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
+    par_for(cs.elements(), [&](auto i) {
+        auto c_idx = cs.multi(i);
        auto a_idx = c_idx;
        auto b_idx = c_idx;
        double s = 0.0;
...
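The reference gemm now iterates the output elements by flat index in parallel and converts each flat index back to a coordinate before doing the dot product over the inner dimension. A standalone sketch of that indexing scheme (a plain serial loop for clarity, not MIGraphX's par_for):

#include <cassert>
#include <cstddef>
#include <vector>

// Naive row-major gemm: C[m x n] = A[m x k] * B[k x n].
// Each output element is addressed by a flat index i, split into (row, col),
// mirroring the cs.multi(i) conversion above.
void gemm(const std::vector<double>& a,
          const std::vector<double>& b,
          std::vector<double>& c,
          std::size_t m,
          std::size_t k,
          std::size_t n)
{
    for(std::size_t i = 0; i < m * n; i++)
    {
        std::size_t row = i / n;
        std::size_t col = i % n;
        double s        = 0.0;
        for(std::size_t t = 0; t < k; t++)
            s += a[row * k + t] * b[t * n + col];
        c[i] = s;
    }
}

int main()
{
    std::vector<double> a = {1, 2, 3, 4}; // 2x2
    std::vector<double> b = {5, 6, 7, 8}; // 2x2
    std::vector<double> c(4, 0);
    gemm(a, b, c, 2, 2, 2);
    assert(c[0] == 19 and c[3] == 50); // C = {19, 22, 43, 50}
    return 0;
}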
@@ -16,7 +16,6 @@
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
-#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
@@ -335,109 +334,6 @@ struct ref_im2col
};
MIGRAPHX_REGISTER_OP(ref_im2col)
-struct max_pool
-{
-    static std::string name() { return "max"; }
-    template <class T>
-    static T start()
-    {
-        return std::numeric_limits<T>::lowest();
-    }
-    static double apply(double x, double y)
-    {
-        double m = std::max(x, y);
-        return (m);
-    }
-    static double final(double x, std::size_t) { return (x); }
-};
-struct avg_pool
-{
-    static std::string name() { return "average"; }
-    template <class T>
-    static double start()
-    {
-        return 0.0;
-    }
-    static double apply(double x, double y) { return x + y; }
-    static double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); }
-};
-template <class Op>
-struct ref_pooling : auto_register_op<ref_pooling<Op>>
-{
-    ref_pooling() = default;
-    ref_pooling(op::pooling pop) : op(std::move(pop)) {}
-    op::pooling op;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-    std::string name() const { return "ref::pooling_" + Op::name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        return op.normalize_compute_shape(inputs);
-    }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
-    {
-        argument result{output_shape};
-        visit_all(result, args[0])([&](auto output, auto input) {
-            using type = typename decltype(output)::value_type;
-            auto in_s = input.get_shape();
-            auto in_lens = in_s.lens();
-            std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());
-            par_for(output_shape.elements(), [&](auto i) {
-                auto idx_o = output_shape.multi(i);
-                auto n_dim = idx_o.size();
-                std::vector<std::size_t> win_start;
-                std::vector<std::size_t> win_size;
-                for(std::size_t dim = 2; dim < n_dim; ++dim)
-                {
-                    auto d_2 = dim - 2;
-                    int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
-                                static_cast<int>(op.padding[d_2]);
-                    int end = std::min(start + op.lengths[d_2], in_lens[dim]);
-                    start = std::max(start, 0);
-                    win_start.push_back(start);
-                    win_size.push_back(end - start);
-                }
-                shape win_shape{output_shape.type(), win_size};
-                auto pool_size = win_shape.elements();
-                double acc = Op::template start<type>();
-                shape_for_each(win_shape, [&](auto idx_w) {
-                    auto idx = idx_o;
-                    std::transform(idx_w.begin(),
-                                   idx_w.end(),
-                                   win_start.begin(),
-                                   idx.begin() + 2,
-                                   [](auto ii, auto jj) { return ii + jj; });
-                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
-                       idx < in_lens)
-                    {
-                        acc = Op::apply(acc, input[in_s.index(idx)]);
-                    }
-                });
-                output[i] = type(Op::final(acc, pool_size));
-            });
-        });
-        return result;
-    }
-};
struct ref_op
{
    operation op = op::identity{};
@@ -609,7 +505,7 @@ struct ref_unary : auto_register_op<ref_unary<Op>>
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(1);
-        auto s = inputs.at(0);
+        const auto& s = inputs.at(0);
        return {s.type(), s.lens()};
    }
@@ -783,11 +679,7 @@ struct ref_apply
        init();
        for(auto it : iterator_for(*mod))
        {
-            if(it->name() == "pooling")
-            {
-                apply_pooling(it);
-            }
-            else if(apply_map.count(it->name()) > 0)
+            if(apply_map.count(it->name()) > 0)
            {
                apply_map.at(it->name())(it);
            }
@@ -815,15 +707,6 @@ struct ref_apply
        auto&& op = any_cast<Op>(ins->get_operator());
        mod->replace_instruction(ins, T{op}, ins->inputs());
    }
-    void apply_pooling(instruction_ref ins) const
-    {
-        auto&& op = any_cast<op::pooling>(ins->get_operator());
-        if(op.mode == op::pooling_mode::max)
-            mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
-        else if(op.mode == op::pooling_mode::average)
-            mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
-    }
};
void lowering::apply(module& m) const { ref_apply{&m}.apply(); }
...
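The removed ref_pooling code computed, for each output coordinate and each spatial dimension, a clamped input window: start = out_idx * stride - padding, end = min(start + window_length, input_length), with start then clamped to zero. A standalone sketch of that window arithmetic for one dimension:

#include <algorithm>
#include <cassert>

// Clamped pooling window for one dimension, as in the removed ref_pooling code.
struct window
{
    int start;
    int size;
};

window pool_window(int out_idx, int stride, int padding, int length, int in_len)
{
    int start = out_idx * stride - padding;
    int end   = std::min(start + length, in_len);
    start     = std::max(start, 0);
    return {start, end - start};
}

int main()
{
    // 1-D input of length 5, window 3, stride 1, padding 1:
    assert(pool_window(0, 1, 1, 3, 5).start == 0); // window clipped at the left edge
    assert(pool_window(0, 1, 1, 3, 5).size == 2);  // only two valid elements remain
    assert(pool_window(2, 1, 1, 3, 5).size == 3);  // fully inside the input
    return 0;
}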
function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
    set(NAME test_api_${TEST_NAME})
    add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
@@ -10,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
    add_dependencies(check ${NAME})
endfunction()
+add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
@@ -19,7 +19,8 @@ add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
+# GPU-based tests
if(MIGRAPHX_ENABLE_GPU)
    add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
-    # GPU-based tests
+    target_link_libraries(test_api_gpu migraphx_gpu)
endif()
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct array2 : migraphx::array_base<array2>
{
std::vector<int> v;
array2() = default;
array2(std::initializer_list<int> x) : v(x) {}
std::size_t size() const { return v.size(); }
int operator[](std::size_t i) const { return v[i]; }
};
TEST_CASE(iterators)
{
array2 a = {1, 2, 3};
EXPECT(bool{std::equal(a.begin(), a.end(), a.v.begin())});
}
TEST_CASE(front_back)
{
array2 a = {1, 2, 3};
EXPECT(a.front() == 1);
EXPECT(a.back() == 3);
}
TEST_CASE(empty)
{
array2 a = {1, 2, 3};
EXPECT(not a.empty());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
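The new test exercises migraphx::array_base, a CRTP helper that derives element accessors (and iterators) for a derived class that only provides size() and operator[]. A minimal standalone sketch of that pattern (the base name and members here are illustrative, not the actual MIGraphX API):

#include <cassert>
#include <cstddef>
#include <vector>

// CRTP base: derives empty/front/back from the derived class's size() and operator[].
// The real migraphx::array_base also supplies begin()/end() iterators.
template <class Derived>
struct array_base_sketch
{
    const Derived& self() const { return static_cast<const Derived&>(*this); }
    bool empty() const { return self().size() == 0; }
    auto front() const { return self()[0]; }
    auto back() const { return self()[self().size() - 1]; }
};

struct array2 : array_base_sketch<array2>
{
    std::vector<int> v;
    std::size_t size() const { return v.size(); }
    int operator[](std::size_t i) const { return v[i]; }
};

int main()
{
    array2 a;
    a.v = {1, 2, 3};
    assert(not a.empty());
    assert(a.front() == 1 and a.back() == 3);
    return 0;
}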
#include <numeric>
+#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
@@ -38,6 +39,7 @@ TEST_CASE(load_and_run_ctx)
        pp.add(name, migraphx::argument::generate(param_shapes[name]));
    }
    auto ctx = p.experimental_get_context();
+    EXPECT(ctx.get_queue<hipStream_t>() != nullptr);
    p.eval(pp);
    ctx.finish();
}
...
@@ -3,23 +3,21 @@
#include <migraphx/migraphx.hpp>
#include "test.hpp"
-TEST_CASE(add_op)
+TEST_CASE(add_literals)
{
    migraphx::program p;
    migraphx::module m = p.get_main_module();
    migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
-    auto x = m.add_parameter("x", param_shape);
-    auto y = m.add_parameter("y", param_shape);
+    std::vector<float> x_values(9, 1);
+    auto x = m.add_literal(param_shape, x_values.data());
+    std::vector<float> y_values(9, -1);
+    auto y = m.add_literal(param_shape, y_values.data());
    auto add_op = migraphx::operation("add");
    auto r = m.add_instruction(add_op, {x, y});
    m.add_return({r});
    // run on ref target
    p.compile(migraphx::target("ref"));
    migraphx::program_parameters pp;
-    std::vector<float> x_data(9, 1);
-    std::vector<float> y_data(9, -1);
-    pp.add("x", migraphx::argument(param_shape, x_data.data()));
-    pp.add("y", migraphx::argument(param_shape, y_data.data()));
    auto outputs = p.eval(pp);
    auto output = outputs[0];
    std::vector<float> expected(9, 0);
@@ -60,16 +58,16 @@ TEST_CASE(if_then_else_op)
        p.compile(migraphx::target("ref"));
        auto outputs =
            p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
-        return outputs;
+        return outputs[0];
    };
    // then branch
    auto then_res = run_prog(true);
-    CHECK(bool{then_res[0] == x_arg});
+    CHECK(bool{then_res == x_arg});
    // else branch
    auto else_res = run_prog(false);
-    CHECK(bool{else_res[0] == y_arg});
+    CHECK(bool{else_res == y_arg});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }