Commit d7dfe995 authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into auto_contig_fix

parents c6ec6638 e3e00547
@@ -106,7 +106,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#endif
    using type = decltype(index::invoke_loop(f, 0, _c<0>));
    __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
-   type x = init;
+   type x = type(init);
    idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
    dpp_reduce(x, op);
@@ -117,7 +117,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
    }
    __syncthreads();
-   type y = init;
+   type y = type(init);
    for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++)
    {
        y = op(y, buffer[i]);
@@ -244,9 +244,8 @@ struct reducer_base
    {
        auto&& derived = static_cast<const Derived&>(*this);
        auto t = derived.slice(x);
-       return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
-           return t[i];
-       });
+       return make_storage_access<typename decltype(t)::type>(
+           [=](auto i, auto...) -> auto& { return t[i]; });
    }
}
@@ -393,7 +392,7 @@ struct block
    {
        using max_iterations = decltype(idx.max_local_stride_iterations(n));
        inner_storage<R, max_iterations{}, N> storage;
-       idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
+       idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = R{f(xs(j, d)...)}; });
        return storage;
    }
};
@@ -482,7 +481,7 @@ struct lane
    __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
    {
        using type = remove_reference_t<decltype(x(0, _c<0>))>;
-       type r = init;
+       type r = type(init);
        for(index_int j = 0; j < n; j++)
        {
            r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
...
@@ -62,7 +62,7 @@ struct avg_pool
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
    {
-       return (y == 0) ? 0.0 : (x / y);
+       return (y == 0) ? T{0.0} : T{x / y};
    }
};
@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
{
    if(xy[ii] < -1.0f or xy[ii] > dims[ii])
    {
-       return 0;
+       return implicit_conversion(0);
    }
    xy[ii] = migraphx::max(xy[ii], 0.0f);
@@ -92,15 +92,16 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
        high[0] * dims[1] + low[1],
        high[0] * dims[1] + high[1]};
    float ly = xy[0] - low[0];
    float lx = xy[1] - low[1];
    float hy = 1.0f - ly;
    float hx = 1.0f - lx;
-   array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
+   // do calculations in floating point and convert final result to required type
+   array<float, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
    auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
    auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
-   return pooling(v01, v23);
+   return implicit_conversion(pooling(v01, v23));
}
template <class Iterator, class Op>
@@ -113,8 +114,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
                                            float roi_offset,
                                            Op op)
{
-   typename Iterator::value_type output_val = op.init();
+   using in_dtype = typename Iterator::value_type;
+   in_dtype output_val = in_dtype{op.init()};
    const int64_t count = bin_grid_size[0] * bin_grid_size[1];
    dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
        array<index_int, 2> id = {iy, ix};
        array<float, 2> locs =
@@ -148,7 +150,6 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
    const auto x = x_t.begin();
    const auto rois = rois_t.begin();
    const auto ind = ind_t.begin();
    // input shape
    auto x_lens = x_t.get_shape().lens;
    auto channel_num = x_lens[1];
@@ -176,10 +177,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
        const auto offset_rois = rois + (n * roi_column_num);
        const int batch_ind = ind[n];
-       array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
-                                     offset_rois[0] * s.spatial_scale};
-       array<float, 2> roi_ends = {offset_rois[3] * s.spatial_scale,
-                                   offset_rois[2] * s.spatial_scale};
+       array<float, 2> roi_starts = {
+           static_cast<float>(offset_rois[1]) * static_cast<float>(s.spatial_scale),
+           static_cast<float>(offset_rois[0]) * static_cast<float>(s.spatial_scale)};
+       array<float, 2> roi_ends = {
+           static_cast<float>(offset_rois[3]) * static_cast<float>(s.spatial_scale),
+           static_cast<float>(offset_rois[2]) * static_cast<float>(s.spatial_scale)};
        array<float, 2> roi_size{};
        array<float, 2> bin_size{};
...
@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output)
        auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
        auto batch_sum =
            r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
-       r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in);
+       r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
    });
}
...
@@ -27,6 +27,7 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
+#include <migraphx/kernels/float8.hpp>
namespace migraphx {
...
@@ -251,7 +251,7 @@ constexpr T numeric_max()
}
template <class T>
-constexpr T numeric_lowest()
+constexpr auto numeric_lowest() -> decltype(numeric_max<T>())
{
    if constexpr(is_integral<T>{})
    {
...
@@ -207,7 +207,7 @@ struct implicit_conversion_op
    template <class U>
    constexpr operator U() const
    {
-       return x;
+       return static_cast<U>(x);
    }
};
...
@@ -37,7 +37,7 @@
#include <mlir-c/Pass.h>
#include <mlir-c/Support.h>
#include <mutex>
-#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
#warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck
#ifndef CPPCHECK
@@ -73,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
@@ -318,31 +319,30 @@ struct mlir_program
        return result;
    }
-   MlirType make_tensor(const shape& s) const
+   MlirType make_mlir_shaped(const shape& s) const
    {
-       if(not s.standard())
-           MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
        if(s.dynamic())
            MIGRAPHX_THROW("MLIR does not support dynamic shapes");
        std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
-       return mlirRankedTensorTypeGet(
-           lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
+       std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
+       return rocmlirMIXRShapedTypeGet(
+           lens.size(), lens.data(), strides.data(), make_type(s.type()));
    }
    template <class Range>
-   std::vector<MlirType> make_tensors(const Range& r)
+   std::vector<MlirType> make_mlir_shapeds(const Range& r)
    {
        std::vector<MlirType> result;
        std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
-           return make_tensor(s);
+           return make_mlir_shaped(s);
        });
        return result;
    }
    MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
    {
-       auto in  = make_tensors(inputs);
-       auto out = make_tensors(outputs);
+       auto in  = make_mlir_shapeds(inputs);
+       auto out = make_mlir_shapeds(outputs);
        return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
    }
@@ -504,11 +504,7 @@ struct mlir_program
    mlir_operation_state& add_results(const std::vector<shape>& outputs)
    {
-       std::vector<shape> reshaped(outputs.size());
-       std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
-           return shape{r.type(), r.lens()};
-       });
-       auto x = prog->make_tensors(reshaped);
+       auto x = prog->make_mlir_shapeds(outputs);
        if(not x.empty())
        {
            mlirOperationStateAddResults(&op_state, x.size(), x.data());
@@ -581,7 +577,7 @@ struct mlir_program
        std::vector<shape> outputs = m.get_output_shapes();
        std::vector<MlirLocation> arg_locs(inputs.size(), location);
-       auto body_inputs = make_tensors(inputs);
+       auto body_inputs = make_mlir_shapeds(inputs);
        mlir_region region = mlirRegionCreate();
        mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
        MlirBlock result = fbody.get();
@@ -607,7 +603,7 @@ struct mlir_program
            return "func.return";
        if(ins->name() == "@literal")
        {
-           return "tosa.const";
+           return "migraphx.literal";
        }
        return "migraphx." + ins->name();
    }
@@ -666,7 +662,8 @@ struct mlir_program
        if(ins->name() == "@literal")
        {
            literal r = ins->get_literal();
-           MlirType tensor_type = make_tensor(ins->get_shape());
+           MlirType shaped_type = make_mlir_shaped(ins->get_shape());
+           MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
            MlirAttribute mlir_value_attr =
                mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
            ops.add_attributes({{"value", mlir_value_attr}});
@@ -796,7 +793,9 @@ struct mlir_program
        if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
            tuning_mode = RocmlirTuningParamSetKindExhaustive;
        mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
-       for(auto i : range(mlirRockTuningGetNumParams(params.get())))
+       const auto limit =
+           value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
+       for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
        {
            mlir_tuning_param param{mlirRockTuningParamCreate()};
            if(not mlirRockTuningParamGet(params.get(), i, param.get()))
@@ -942,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
        auto param = m.get_parameter(name);
        if(input.standard())
            continue;
-       auto lens = input.lens();
-       auto strides = input.strides();
-       std::vector<operation> ops;
-       if(input.transposed())
-       {
-           auto perm = find_permutation(input);
-           auto iperm = invert_permutation(perm);
-           lens = reorder_dims(lens, iperm);
-           strides = reorder_dims(strides, iperm);
-           ops.push_back(make_op("transpose", {{"permutation", perm}}));
-       }
-       if(input.broadcasted())
-       {
-           std::transform(lens.begin(),
-                          lens.end(),
-                          strides.begin(),
-                          lens.begin(),
-                          [](auto len, auto stride) -> std::size_t {
-                              if(stride == 0)
-                                  return 1;
-                              return len;
-                          });
-           ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
-       }
-       auto new_param =
-           std::accumulate(ops.begin(),
-                           ops.end(),
-                           m.add_parameter(name + ".0", shape{input.type(), lens}),
-                           [&](auto x, auto op) { return m.insert_instruction(param, op, x); });
+       auto new_param = m.add_parameter(name + ".0", input);
        m.replace_instruction(param, new_param);
        m.remove_instruction(param);
    }
@@ -1032,6 +1003,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
    mlir_program mp;
    mp.set_gpu_properties(migraphx_ctx);
    mp.parse(m);
+   const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
+   static std::mutex mutex;
+   if(trace)
+   {
+       const std::lock_guard<std::mutex> lock(mutex);
+       auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
+       std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
+   }
    return mp.get_tuning_config(exhaustive);
}
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_pad::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs, *this}.has(1).standard();
return op.compute_shape(inputs);
}
argument hip_pad::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::pad(ctx.get_stream().get(), args.back(), args.front(), op.value, op.pads);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -31,6 +31,7 @@
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
+#include <migraphx/gpu/fuse_mlir.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -124,34 +125,55 @@ struct find_add_layernorm
    }
};
-#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
    std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
-MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
-{
-   if(ins->name() != "dot")
-       return false;
-   if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
-       return false;
-   return true;
-}
+auto is_ck_gemm()
+{
+   return match::make_basic_pred_matcher([=](instruction_ref ins) {
+#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
+       if(not enabled(MIGRAPHX_ENABLE_CK{}))
+           return false;
+       if(ins->name() != "dot")
+           return false;
+       if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
+           return false;
+       return true;
+#else
+       (void)ins;
+       return false;
+#endif
+   });
+}
+
+auto is_mlir_gemm()
+{
+   return match::make_basic_pred_matcher([=](instruction_ref ins) {
+       if(not mlir_attention_enabled())
+           return false;
+       if(ins->name() != "dot")
+           return false;
+       return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
+           return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
+       });
+   });
+}
struct find_gemm_softmax_gemm
{
    auto matcher() const
    {
-       auto gemm1 =
-           match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+       auto gemm1 = match::skip(match::name("contiguous"))(
+           match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
        auto mul = match::name("mul")(
            match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
        auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
-       return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
+       return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
+           match::arg(0)(softmax));
    }
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -179,8 +201,6 @@ struct find_gemm_softmax_gemm
    }
};
-#endif
} // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const
@@ -188,10 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
    match::find_matches(mpm.get_module(), find_layernorm{});
    mpm.run_pass(dead_code_elimination{});
    match::find_matches(mpm.get_module(), find_add_layernorm{});
-#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL
-   if(enabled(MIGRAPHX_ENABLE_CK{}))
-       match::find_matches(mpm, find_gemm_softmax_gemm{});
-#endif
+   match::find_matches(mpm, find_gemm_softmax_gemm{});
}
} // namespace gpu
...
@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
    ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
    std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
    unsupported_types.erase(shape::type_t::float_type);
+   unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
    unsupported_types.erase(shape::type_t::half_type);
    unsupported_types.erase(shape::type_t::bool_type);
    unsupported_types.erase(shape::type_t::int8_type);
...
@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
        auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
        mm->add_return({tanh});
    }
-   migraphx::program p2(p1);
-   // This pass should do nothing as int32_t tanh isn't supported.
+   // This pass should not fuse as int32_t tanh isn't supported.
    run_pass(p1);
-   EXPECT(p1 == p2);
+   auto* mm = p1.get_main_module();
+   bool has_pointwise =
+       std::any_of(mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
+   EXPECT(has_pointwise);
}
int main(int argc, const char* argv[])
...
@@ -350,18 +350,19 @@ TEST_CASE(compile_math)
    auto vec_sizes = {2, 4, 6};
    for(auto&& t : migraphx::shape::types())
    {
-       if(contains({migraphx::shape::bool_type,
-                    migraphx::shape::fp8e4m3fnuz_type,
-                    migraphx::shape::tuple_type},
-                   t))
+       if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
            continue;
        auto name = migraphx::shape::cpp_type(t);
        if(t == migraphx::shape::half_type)
            name.insert(0, "migraphx::");
        data_types.push_back(name);
-       migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
-           return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
-       });
+       // fp8 doesn't have vectorization support yet, therefore skip it for now.
+       if(t != migraphx::shape::fp8e4m3fnuz_type)
+       {
+           migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
+               return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
+           });
+       }
    }
    migraphx::shape input{migraphx::shape::float_type, {5, 2}};
    migraphx::gpu::hip_compile_options options;
@@ -431,7 +432,6 @@ TEST_CASE(assert_type_min_max)
        min = std::to_string(as.min());
        max = std::to_string(as.max());
    }
    auto src = migraphx::interpolate_string(assert_template,
                                            {{"type", name}, {"max", max}, {"min", min}});
    migraphx::shape input{migraphx::shape::float_type, {5, 2}};
...
@@ -141,9 +141,9 @@ TEST_CASE(conv)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
-    return %0 : tensor<1x2x2x2xf32>
+  func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
  }
}
)__migraphx__";
@@ -160,15 +160,38 @@ module {
    EXPECT(verify_mlir(m));
}
TEST_CASE(conv_nhwc)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
m.add_return({conv});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_add_relu)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
-    %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
-    %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
-    return %2 : tensor<1x2x2x2xf32>
+  func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
  }
}
)__migraphx__";
@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
-    %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
-    return %1 : tensor<1x5x3xi32>
+  func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
  }
}
)__migraphx__";
@@ -219,10 +242,10 @@ TEST_CASE(dot_add)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
-    return %1 : tensor<1x5x3xf32>
+  func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
  }
}
)__migraphx__";
@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
-    %1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
-    %2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
-    return %2 : tensor<1x2x2x2xi32>
+  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
+    %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
+    return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
  }
}
)__migraphx__";
@@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
-    return %1 : tensor<1x5x3xf16>
+  func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
  }
}
)__migraphx__";
@@ -304,10 +327,10 @@ TEST_CASE(dot_where)
{
    const std::string mlir_output = R"__migraphx__(
module {
-  func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
-    return %1 : tensor<1x5x3xf32>
+  func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
  }
}
)__migraphx__";
...
-a5537f2f563d4975c7e6121a7eb260bbbfd9455a
+d69842226b47e5336568103541b071447caeb9bf
@@ -9543,6 +9543,97 @@ def undefined_test():
    return ([node], [x], [y])
@onnx_test()
def unique_dynamic_sorted_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=1)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_dynamic_sorted_3D_test():
x = helper.make_tensor_value_info('X', TensorProto.INT64, [4, 4, 4])
y = helper.make_tensor_value_info('Y', TensorProto.INT64, [16])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [16])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[64])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [16])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
sorted=1)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_dynamic_unsorted_test():
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=0)
return ([node], [x], [y, y_ind, x_ind, count])
@onnx_test()
def unique_sorted_test():
x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=1)
return ([node], [], [y, y_ind, x_ind, count], [x])
@onnx_test()
def unique_unsorted_test():
x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64,
[6])
count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
node = onnx.helper.make_node(
'Unique',
inputs=['X'],
outputs=['Y', 'indices', 'inverse_indices', 'counts'],
axis=0,
sorted=0)
return ([node], [], [y, y_ind, x_ind, count], [x])
@onnx_test()
def unknown_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
...
@@ -4826,8 +4826,9 @@ TEST_CASE(multinomial_test)
    migraphx::shape s{migraphx::shape::float_type, {1}};
    std::vector<float> seed_data = {seed};
    auto seed_input = mm->add_literal(migraphx::literal(s, seed_data));
-   auto rand_dummy =
-       mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
+   auto rand_dummy = mm->add_literal(
+       migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
+                         std::vector<float>(batch_size * sample_size)});
    auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
    mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
@@ -4978,8 +4979,9 @@ TEST_CASE(multinomial_int64_test)
    auto seed_input = mm->add_literal(migraphx::literal(s, data));
    // static size
-   auto rand_dummy =
-       mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
+   auto rand_dummy = mm->add_literal(
+       migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
+                         std::vector<float>(batch_size * sample_size)});
    auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
    mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", dtype}}), cdf, randoms);
    auto prog = optimize_onnx("multinomial_int64_test.onnx");
@@ -8604,6 +8606,86 @@ TEST_CASE(undefined_test)
    EXPECT(p == prog);
}
TEST_CASE(unique_dynamic_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_dynamic_sorted_3D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {4, 4, 4}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_3D_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_sorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unique_unsorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 0}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_unsorted_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test)
{
    migraphx::program p;
...
(Binary file added: ONNX protobuf for the unique_dynamic_sorted_3D_test model; contents not human-readable.)