Unverified Commit dc8edad8 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into check-mlir-perf

parents b9cbfd8e 32f0b028
......@@ -29,11 +29,15 @@
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
constexpr float as_float(migraphx::fp8::fp8e4m3fnuz x) { return x; }
template <class T>
constexpr T as_float(T x)
{
......@@ -57,14 +61,14 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
auto __device__ name(type x, Ts... xs) -> type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
inline auto __device__ name(type x, type y) -> type { return fname(x, y); }
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
......@@ -72,6 +76,12 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
......@@ -162,6 +172,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// use float to compute fp8 overload
MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs)
MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos)
MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin)
MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan)
MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos)
MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf)
MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp)
MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor)
MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_FP8(log, ::log)
MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow)
MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_FP8(round, ::round)
MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin)
MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan)
MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
......@@ -253,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) -> T { return x; });
return vec_transform(v)([](auto x) -> T { return static_cast<T>(x); });
}
} // namespace migraphx
......
......@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
......@@ -53,9 +54,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
}))
output[multi] = pad_val;
output[multi] = implicit_conversion(pad_val);
else
output[multi] = input[input_idx];
output[multi] = implicit_conversion(input[input_idx]);
});
}
......
......@@ -106,7 +106,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#endif
using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init;
type x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op);
......@@ -117,7 +117,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
}
__syncthreads();
type y = init;
type y = type(init);
for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++)
{
y = op(y, buffer[i]);
......@@ -244,9 +244,8 @@ struct reducer_base
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
return t[i];
});
return make_storage_access<typename decltype(t)::type>(
[=](auto i, auto...) -> auto& { return t[i]; });
}
}
......@@ -393,7 +392,7 @@ struct block
{
using max_iterations = decltype(idx.max_local_stride_iterations(n));
inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = R{f(xs(j, d)...)}; });
return storage;
}
};
......@@ -482,7 +481,7 @@ struct lane
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
using type = remove_reference_t<decltype(x(0, _c<0>))>;
type r = init;
type r = type(init);
for(index_int j = 0; j < n; j++)
{
r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
......
......@@ -62,7 +62,7 @@ struct avg_pool
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{
return (y == 0) ? 0.0 : (x / y);
return (y == 0) ? T{0.0} : T{x / y};
}
};
......@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{
if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{
return 0;
return implicit_conversion(0);
}
xy[ii] = migraphx::max(xy[ii], 0.0f);
......@@ -92,15 +92,16 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
// do calculations in floating point and convert final result to required type
array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23);
return implicit_conversion(pooling(v01, v23));
}
template <class Iterator, class Op>
......@@ -113,8 +114,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
float roi_offset,
Op op)
{
typename Iterator::value_type output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
using in_dtype = typename Iterator::value_type;
in_dtype output_val = in_dtype{op.init()};
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<index_int, 2> id = {iy, ix};
array<float, 2> locs =
......@@ -148,7 +150,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto x = x_t.begin();
const auto rois = rois_t.begin();
const auto ind = ind_t.begin();
// input shape
auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1];
......@@ -176,10 +177,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale};
array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale,
offset_rois[2] * s.spatial_scale};
array<float, 2> roi_starts = {
static_cast<float>(offset_rois[1]) * static_cast<float>(s.spatial_scale),
static_cast<float>(offset_rois[0]) * static_cast<float>(s.spatial_scale)};
array<float, 2> roi_ends = {
static_cast<float>(offset_rois[3]) * static_cast<float>(s.spatial_scale),
static_cast<float>(offset_rois[2]) * static_cast<float>(s.spatial_scale)};
array<float, 2> roi_size{};
array<float, 2> bin_size{};
......
......@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in);
r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
});
}
......
......@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
#include <migraphx/kernels/float8.hpp>
namespace migraphx {
......
......@@ -251,7 +251,7 @@ constexpr T numeric_max()
}
template <class T>
constexpr T numeric_lowest()
constexpr auto numeric_lowest() -> decltype(numeric_max<T>())
{
if constexpr(is_integral<T>{})
{
......
......@@ -207,7 +207,7 @@ struct implicit_conversion_op
template <class U>
constexpr operator U() const
{
return x;
return static_cast<U>(x);
}
};
......
......@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
......@@ -796,7 +797,9 @@ struct mlir_program
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get())))
const auto limit =
value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
{
mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get()))
......@@ -1032,6 +1035,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
return mp.get_tuning_config(exhaustive);
}
......
......@@ -31,6 +31,7 @@
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
#include <migraphx/gpu/fuse_mlir.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -124,34 +125,55 @@ struct find_add_layernorm
}
};
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto is_ck_gemm()
{
if(ins->name() != "dot")
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return match::make_basic_pred_matcher([=](instruction_ref ins) {
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if(not enabled(MIGRAPHX_ENABLE_CK{}))
return false;
if(ins->name() != "dot")
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
return true;
#else
(void)ins;
return false;
return true;
#endif
});
}
auto is_mlir_gemm()
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(not mlir_attention_enabled())
return false;
if(ins->name() != "dot")
return false;
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
});
});
}
struct find_gemm_softmax_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto gemm1 = match::skip(match::name("contiguous"))(
match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
match::arg(0)(softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
......@@ -179,8 +201,6 @@ struct find_gemm_softmax_gemm
}
};
#endif
} // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const
......@@ -188,10 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#endif
match::find_matches(mpm, find_gemm_softmax_gemm{});
}
} // namespace gpu
......
......@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type);
......
......@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh});
}
migraphx::program p2(p1);
// This pass should do nothing as int32_t tanh isn't supported.
// This pass should not fuse as int32_t tanh isn't supported.
run_pass(p1);
EXPECT(p1 == p2);
auto* mm = p1.get_main_module();
bool has_pointwise =
std::any_of(mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
EXPECT(has_pointwise);
}
int main(int argc, const char* argv[])
......
......@@ -350,18 +350,19 @@ TEST_CASE(compile_math)
auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type,
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
data_types.push_back(name);
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
// fp8 doesn't have vectorization support yet, therefore skip it for now.
if(t != migraphx::shape::fp8e4m3fnuz_type)
{
migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
});
}
}
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::hip_compile_options options;
......@@ -431,7 +432,6 @@ TEST_CASE(assert_type_min_max)
min = std::to_string(as.min());
max = std::to_string(as.max());
}
auto src = migraphx::interpolate_string(assert_template,
{{"type", name}, {"max", max}, {"min", min}});
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
......
a5537f2f563d4975c7e6121a7eb260bbbfd9455a
d69842226b47e5336568103541b071447caeb9bf
......@@ -9543,6 +9543,97 @@ def undefined_test():
return ([node], [x], [y])
@onnx_test()
def unique_dynamic_sorted_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=1)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_dynamic_sorted_3D_test():
x = helper.make_tensor_value_info('X', TensorProto.INT64, [4, 4, 4])
y = helper.make_tensor_value_info('Y', TensorProto.INT64, [16])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [16])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[64])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [16])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
sorted=1)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_dynamic_unsorted_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=0)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_sorted_test():
x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=1)
return ([node], [], [y, y_ind, x_ind, count], [x])
@onnx_test()
def unique_unsorted_test():
x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=0)
return ([node], [], [y, y_ind, x_ind, count], [x])
@onnx_test()
def unknown_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
......@@ -4826,8 +4826,9 @@ TEST_CASE(multinomial_test)
migraphx::shape s{migraphx::shape::float_type, {1}};
std::vector<float> seed_data = {seed};
auto seed_input = mm->add_literal(migraphx::literal(s, seed_data));
auto rand_dummy =
mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
auto rand_dummy = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
......@@ -4978,8 +4979,9 @@ TEST_CASE(multinomial_int64_test)
auto seed_input = mm->add_literal(migraphx::literal(s, data));
// static size
auto rand_dummy =
mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
auto rand_dummy = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", dtype}}), cdf, randoms);
auto prog = optimize_onnx("multinomial_int64_test.onnx");
......@@ -8604,6 +8606,86 @@ TEST_CASE(undefined_test)
EXPECT(p == prog);
}
TEST_CASE(unique_dynamic_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_dynamic_sorted_3D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {4, 4, 4}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_3D_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_sorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_unsorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 0}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_unsorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test)
{
migraphx::program p;
......
 unique_dynamic_sorted_3D_test:Ö
?
XYindicesinverse_indicescounts"Unique*
sorted unique_dynamic_sorted_3D_testZ
X



b
Y

b
indices

b
inverse_indices

@b
counts

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment