Commit fe493c28 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents ba0b3794 cce35871
......@@ -33,7 +33,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
......
......@@ -31,7 +31,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
......
......@@ -78,7 +78,9 @@ struct concat_compiler : compiler<concat_compiler>
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(ctx, axis, options.inputs);
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(
......
......@@ -32,7 +32,7 @@ namespace gpu {
struct mlir_compiler : compiler<mlir_compiler>
{
std::vector<std::string> names() const { return {"gpu::mlir_conv"}; }
std::vector<std::string> names() const { return {"gpu::mlir_op"}; }
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
......
......@@ -60,15 +60,6 @@ __global__ void reduce_kernel(void* input_p, void* output_p)
)__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
......@@ -86,9 +77,28 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
return reduce_lens;
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
auto lens = s.lens();
std::fill(lens.begin(), lens.end(), 1);
for(const auto& axis : axes)
lens[axis] = s.lens()[axis];
return shape{s.type(), lens};
}
template <class T>
static shape get_output_shape(const shape& s, const std::vector<T>& axes)
{
auto lens = s.lens();
for(const auto& axis : axes)
lens[axis] = 1;
return shape{s.type(), lens};
}
template <class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
......@@ -103,11 +113,27 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
return "block";
}
struct reduce_compiler : compiler<reduce_compiler>
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
return get_reduce_algo(inputs, rlens);
}
struct simple_reduce_compiler : compiler<simple_reduce_compiler>
{
std::vector<std::string> names() const
{
return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
return {"simple_reduce",
"reduce_sum",
"reduce_mean",
"reduce_max",
"reduce_min",
"reduce_prod"};
}
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
......@@ -157,44 +183,108 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type},
reduce_type))
v["read"] = mean;
else
v["write"] = mean;
reduce_op r{};
r.set(ins, op);
v["reduction"] = r.reduction;
v["read"] = r.read;
v["write"] = r.write;
v["init"] = r.init;
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
else if(op.name() == "reduce_max")
};
static const char* const fused_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
std::vector<std::string> names() const { return {"fused_reduce"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
v["reduction"] = "op::max{}";
v["init"] = "lowest{}";
}
else if(op.name() == "reduce_min")
auto axes = v.at("axes").to_vector<std::size_t>();
auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs);
auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back();
virtual_inputs.pop_back();
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = virtual_inputs;
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
if(algo == "block")
{
v["reduction"] = "op::min{}";
v["init"] = "highest{}";
// Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size;
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(op.name() == "reduce_prod")
else if(algo == "lane")
{
v["reduction"] = "op::product{}";
v["init"] = "1";
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = v.get("kernel", "reduce_kernel");
auto src = interpolate_string(
fused_reduce_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(not ins->module_inputs().empty());
auto v = op.to_value();
auto* rm = ins->module_inputs().front();
v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
......
......@@ -204,6 +204,14 @@ constexpr auto compose(Fs... fs)
})(fs...);
}
template <class F>
constexpr auto partial(F f)
{
return [=](auto... xs) {
return [=](auto&&... ys) { return f(xs..., static_cast<decltype(ys)>(ys)...); };
};
}
template <class... Ts>
constexpr auto pack(Ts... xs)
{
......
......@@ -241,6 +241,12 @@ struct index
}
};
#ifdef MIGRAPHX_NLOCAL
#define MIGRAPHX_GLOBAL \
__global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
#else
#define MIGRAPHX_GLOBAL __global__
#endif
inline __device__ __attribute__((const)) index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
......
......@@ -174,6 +174,25 @@ struct inner_storage_tag
template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
template <class Size, class F>
struct lazy_inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
constexpr lazy_inner_storage<Size, F> make_lazy_inner_storage(Size, F f)
{
return {{}, f};
}
template <class R, class F>
struct storage_access : F
{
......@@ -278,6 +297,14 @@ struct reducer_base
});
}
template <class F>
__device__ auto lazy_inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
});
}
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
......@@ -396,25 +423,6 @@ struct block_large
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
......@@ -439,7 +447,7 @@ struct block_large
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
......@@ -469,25 +477,6 @@ struct lane
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
......@@ -518,7 +507,7 @@ struct lane
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
template <class Slicer>
......@@ -577,5 +566,21 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
});
}
template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f)
{
Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r);
if constexpr(reduce::is_inner_storage<decltype(result)>{})
{
r.inner([&](auto& y, auto x) { y = x; })(output, result);
}
else
{
r.outer([&] { output[out_idx] = implicit_conversion(result); });
}
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
......@@ -30,6 +30,7 @@
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/Dialect/Rock.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#include <mutex>
......@@ -55,12 +56,16 @@
#include <migraphx/permutation.hpp>
#include <deque>
#include <variant>
#include <fstream>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
#ifdef MIGRAPHX_MLIR
template <class T, class F, F f> // NOLINT
......@@ -124,6 +129,8 @@ using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
using mlir_region = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
mlirRockTuningTableDestroy);
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
......@@ -455,7 +462,7 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
{"sym_name", sym_name},
{"kernel", std::string("mixr")},
{"arch", target_arch}});
ops.add_region(std::move(region));
......@@ -498,11 +505,25 @@ struct mlir_program
return ins->get_shape();
}
static std::string get_symbol_name(const module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "convolution" or ins->name() == "dot")
{
return "mlir_" + ins->name();
}
}
return "main";
}
void parse(const module& m)
{
sym_name = get_symbol_name(m);
auto mbody = mlirModuleGetBody(mmodule.get());
std::unordered_map<instruction_ref, MlirValue> ins_map;
auto fbody = insert(mbody, m, ins_map);
for(auto ins : iterator_for(m))
{
if(ins->name() == "@param")
......@@ -512,16 +533,13 @@ struct mlir_program
ops.add_attribute_value(get_operator_value(ins->get_operator()));
if(ins->name() != "@return")
ops.add_results({get_shape(ins)});
if(ins->name() == "convolution")
if(ins->name() == "convolution" or ins->name() == "dot")
{
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
if(xdlops)
ops.add_attributes({{"xdlopsV2", true}});
}
......@@ -542,15 +560,19 @@ struct mlir_program
code_object_op compile() MIGRAPHX_TIDY_CONST
{
mlir_pass_manager pm{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRun(pm_front.get(), mmodule.get());
// 2nd pipeline to call
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
mlirPassManagerRun(pm.get(), mmodule.get());
get_module_tuned();
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRun(pm_back.get(), mmodule.get());
code_object_op op{};
op.symbol_name = "main";
op.symbol_name = sym_name;
op.code_object = get_binary();
std::tie(op.global, op.local) = get_launch_params();
return op;
......@@ -578,7 +600,74 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program");
}
std::string get_tune_params(bool xdlops) { return get_mlir_perf_for_conv(pp, xdlops); }
std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }
// This function appends to tuning cfg file that could be
// used with rocMLIR tuning scripts.
void dump_tuning_cfg(const char* prob_config) const
{
std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{});
if(!tuning_cfg_path.empty())
{
std::vector<std::string> tokens = split_string(prob_config, '\t');
std::string prob = tokens[1];
if(starts_with(prob, "conv"))
{
tuning_cfg_path += ".conv";
}
else
{
tuning_cfg_path += ".gemm";
}
std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app);
tuning_cfg << prob << std::endl;
}
}
static mlir_tuning_table create_tuning_table()
{
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{});
if(!tuning_db_path.empty())
{
std::ifstream tuning_db_tsv(tuning_db_path);
if(tuning_db_tsv)
{
std::string line;
while(std::getline(tuning_db_tsv, line))
{
std::vector<std::string> tokens = split_string(line, '\t');
std::string arch = tokens[0];
std::string prob = tokens[1];
std::string perf = tokens[2];
std::string key = arch.append("\t").append(prob);
mlirRockTuningUpdateTable(tuning_table.get(), key.c_str(), perf.c_str(), 1.0);
}
}
}
else
{
std::cerr
<< "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
"optimal performance."
<< std::endl;
}
return tuning_table;
}
bool get_module_tuned() const
{
static mlir_tuning_table tuning_table = create_tuning_table();
if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
{
const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get());
std::stringstream key(prob_config);
std::cerr << "fails to set param on" << prob_config << std::endl;
dump_tuning_cfg(prob_config);
return false;
}
return true;
}
mlir_context ctx;
MlirLocation location;
......@@ -586,6 +675,7 @@ struct mlir_program
problem_params pp;
std::deque<std::string> strings{};
std::string target_arch;
std::string sym_name;
};
std::string dump_mlir(const module& m)
......
......@@ -26,13 +26,13 @@
#include <migraphx/check_context.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
......@@ -40,7 +40,7 @@
#include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_gelu.hpp>
......@@ -48,9 +48,9 @@
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
......@@ -75,6 +75,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
{
......@@ -103,6 +104,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
// clang-format off
return
{
enable_pass(options.split_single_dyn_dim, split_single_dyn_dim{}),
enable_pass(options.split_single_dyn_dim, dead_code_elimination{}),
normalize_ops{},
dead_code_elimination{},
simplify_qdq{},
......@@ -132,6 +135,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
fuse_mlir{&ctx},
dead_code_elimination{},
fuse_ck{&ctx},
......@@ -153,6 +158,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
compile_ops{&ctx},
dead_code_elimination{},
promote_literals{},
dead_code_elimination{},
write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"},
......
......@@ -31,10 +31,9 @@ set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
find_package(Threads)
rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref migraphx Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
......
......@@ -110,7 +110,7 @@ function(add_test_executable TEST_NAME)
add_test_command(${TEST_NAME} ${TEST_COMMAND})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable)
......@@ -163,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx)
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_dependencies(tests ${TEST_NAME})
......@@ -218,3 +218,10 @@ test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/r
if(MIGRAPHX_ENABLE_GPU)
test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp)
endif()
if(MIGRAPHX_ENABLE_CPU)
test_headers(migraphx/cpu ${CMAKE_SOURCE_DIR}/src/targets/cpu/include/migraphx/cpu/*.hpp)
endif()
if(MIGRAPHX_ENABLE_FPGA)
test_headers(migraphx/fpga ${CMAKE_SOURCE_DIR}/src/targets/fpga/include/migraphx/fpga/*.hpp)
endif()
......@@ -25,7 +25,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
rocm_clang_tidy_check(${NAME})
target_link_libraries(${NAME} migraphx_c migraphx)
target_link_libraries(${NAME} migraphx_c migraphx migraphx_all_targets)
target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME})
......@@ -59,7 +59,7 @@ if(MIGRAPHX_ENABLE_GPU)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_gpu hip::host)
target_link_libraries(test_api_gpu)
add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_custom_op_gpu hip::host)
target_link_libraries(test_api_custom_op_gpu)
endif()
......@@ -36,7 +36,7 @@ bool create_shapes(bool dynamic_allowed)
try
{
shape a{shape::int64_type, {3}};
shape b{shape::float_type, {{3, 6, 0}, {4, 4, 0}}};
shape b{shape::float_type, {{3, 6}, {4, 4}}};
auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}});
}
bool all_instructions_are_local(const migraphx::module& m)
{
return std::all_of(m.begin(), m.end(), [&](const auto& ins) {
return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) {
return m.has_instruction(input);
});
});
}
template <class F>
migraphx::instruction_ref add_reduce(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
const std::vector<int64_t>& axes,
F f)
{
auto* rm = p.create_module(name);
auto* mm = p.get_main_module();
rm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return rm->add_parameter(
"x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type(), input->get_shape().lens()});
});
auto r = f(rm, params, axes);
rm->add_return({r});
EXPECT(all_instructions_are_local(*rm));
return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm});
}
inline auto single_reduce(const std::string& name)
{
return [=](auto* rm, const auto& inputs, const auto& axes) {
return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs);
};
}
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y);
mm->add_return({rsum1, rsum2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum"));
mm->add_return({rsum1, rsum2});
}
EXPECT(p1 == p2);
}
TEST_CASE(pointwise_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add);
mm->add_return({rsum});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = add_reduce(
p2,
"main:pointwise0:main:reduce_sum0",
{x, y},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto add =
add_pointwise(p2, rm, "main:pointwise0", inputs, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_reduce(
p2,
"main:reduce_sum0:main:pointwise0",
{x, y},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
return add_pointwise(
p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add"));
});
mm->add_return({add});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto rsumdiff = add_pointwise(p1, "main:pointwise0", {rsumb, x}, single_pointwise("sub"));
auto rsum2 =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), rsumdiff);
auto sqrt = add_pointwise(p1, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
mm->add_return({sqrt});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto sqrt = add_reduce(
p2,
"main:reduce_sum1:main:reduce_sum0:main:pointwise0:main:pointwise1",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto rsumdiff = add_pointwise(
p2, rm, "main:pointwise0", {rsumb, inputs[0]}, single_pointwise("sub"));
auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
rsumdiff);
return add_pointwise(p2, rm, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
});
mm->add_return({sqrt});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_reduce_mismatch_axis)
{
migraphx::shape s{migraphx::shape::float_type, {4, 2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsum1);
mm->add_return({rsum2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {rsum1}, {2}, single_reduce("reduce_sum"));
mm->add_return({rsum2});
}
EXPECT(p1 == p2);
}
TEST_CASE(pointwise_reduce_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto sqrt = add_pointwise(p1, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
auto sqrtb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
auto add1 = add_pointwise(p1, "main:pointwise1", {sqrtb, x}, single_pointwise("add"));
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add1);
auto add2 = add_pointwise(p1, "main:pointwise2", {rsum2, rsum1}, single_pointwise("add"));
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add2 = add_reduce(
p2,
"main:pointwise0:main:pointwise1:main:reduce_sum1:main:pointwise2:main:reduce_sum0",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto sqrt =
add_pointwise(p2, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
auto sqrtb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
auto add1 = add_pointwise(
p2, rm, "main:pointwise1", {sqrtb, inputs[0]}, single_pointwise("add"));
auto rsum2 =
rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add1);
return add_pointwise(
p2, rm, "main:pointwise2", {rsum2, rsum1}, single_pointwise("add"));
});
mm->add_return({add2});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_reduce_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {4, 2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = add_reduce(p1, "test:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
auto add = add_reduce(
p1,
"test:reduce_sum1",
{rsumb, x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto add2 =
add_pointwise(p1, rm, "test:pointwise0", inputs, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add2);
});
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = add_reduce(
p2,
"test:reduce_sum1:test:reduce_sum0",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
auto add = add_pointwise(
p2, rm, "test:pointwise0", {rsumb, inputs[0]}, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -27,7 +27,7 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/target.hpp>
TEST_CASE(tuple_to_from_gpu)
TEST_CASE(tuple_from_gpu)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
......@@ -47,4 +47,23 @@ TEST_CASE(tuple_to_from_gpu)
EXPECT(result2 == p2_data);
}
TEST_CASE(tuple_to_gpu)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
std::vector<float> p1_data = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6};
std::vector<int> p2_data = {1, 2, 3, 4, 5, 6, 7, 8};
auto p1 = migraphx::argument{s1, p1_data.data()};
auto p2 = migraphx::argument{s2, p2_data.data()};
auto p_gpu = migraphx::gpu::to_gpu(migraphx::argument({p1, p2}));
auto p_host = migraphx::gpu::from_gpu(p_gpu);
std::vector<migraphx::argument> results = p_host.get_sub_objects();
std::vector<float> result1;
results[0].visit([&](auto output) { result1.assign(output.begin(), output.end()); });
std::vector<int> result2;
results[1].visit([&](auto output) { result2.assign(output.begin(), output.end()); });
EXPECT(result1 == p1_data);
EXPECT(result2 == p2_data);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -140,7 +140,7 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32>
}
......@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
func.func @mlir_convolution(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......@@ -187,4 +187,30 @@ module {
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::float_type, {1, 5, 3}});
auto conv = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto add = m.add_instruction(migraphx::make_op("add"), conv, arg2);
m.add_return({add});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -30,12 +30,12 @@
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
......@@ -47,6 +47,15 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
return add_pointwise(p, p.get_main_module(), name, inputs, f);
}
inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
......
......@@ -186,14 +186,13 @@ TEST_CASE(argmax_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"x",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
"x", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {5, 5}, {6, 6}}});
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("argmax_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -296,8 +295,7 @@ TEST_CASE(averagepool_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}, {5, 5, 0}}});
"0", {migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}, {5, 5}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0, 0, 0}},
......@@ -307,7 +305,7 @@ TEST_CASE(averagepool_dyn_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("averagepool_dyn_test.onnx", options);
EXPECT(p == prog);
}
......@@ -315,7 +313,7 @@ TEST_CASE(averagepool_dyn_test)
TEST_CASE(averagepool_dyn_autopad_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("averagepool_dyn_autopad_error_test.onnx", options); }));
}
......@@ -323,7 +321,7 @@ TEST_CASE(averagepool_dyn_autopad_error_test)
TEST_CASE(averagepool_dyn_asym_padding_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("averagepool_dyn_asym_padding_error_test.onnx", options); }));
}
......@@ -331,7 +329,7 @@ TEST_CASE(averagepool_dyn_asym_padding_error_test)
TEST_CASE(averagepool_dyn_cip_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("averagepool_dyn_cip_error_test.onnx", options); }));
}
......@@ -589,15 +587,14 @@ TEST_CASE(binary_dyn_brcst_prelu_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto ret = add_common_op(*mm, migraphx::make_op("prelu"), {l0, l1});
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("binary_dyn_brcst_prelu_test.onnx", options);
EXPECT(p == prog);
......@@ -609,14 +606,13 @@ TEST_CASE(binary_dyn_brcst_add_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, {4, 5}});
auto l1 = mm->add_parameter(
"1",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
"1", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}});
auto ret = add_common_op(*mm, migraphx::make_op("add"), {l0, l1});
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("binary_dyn_brcst_add_test.onnx", options);
EXPECT(p == prog);
......@@ -625,7 +621,7 @@ TEST_CASE(binary_dyn_brcst_add_test)
TEST_CASE(binary_dyn_brcst_attr_error_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws(
[&] { migraphx::parse_onnx("binary_dyn_brcst_attr_error_test.onnx", options); }));
}
......@@ -635,8 +631,7 @@ TEST_CASE(binary_dyn_brcst_mul_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1}});
auto bl1 = mm->add_instruction(
......@@ -648,7 +643,7 @@ TEST_CASE(binary_dyn_brcst_mul_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("binary_dyn_brcst_mul_test.onnx", options);
EXPECT(p == prog);
......@@ -845,15 +840,15 @@ TEST_CASE(concat_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1, 4}, {3, 3}}});
auto l1 = mm->add_parameter(
"1", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {1, 4, 0}, {3, 3, 0}}});
"1", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1, 4}, {3, 3}}});
auto ret = mm->add_instruction(migraphx::make_op("concat"), l0, l1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("concat_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -1120,8 +1115,8 @@ TEST_CASE(conv_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 6, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}});
auto l0 =
mm->add_parameter("0", {migraphx::shape::float_type, {{1, 6}, {3, 3}, {5, 5}, {5, 5}}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
......@@ -1131,7 +1126,7 @@ TEST_CASE(conv_dynamic_batch_test)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 6, 0};
options.default_dyn_dim_value = {1, 6};
auto prog = migraphx::parse_onnx("conv_dynamic_batch_test.onnx", options);
EXPECT(p == prog);
......@@ -1141,8 +1136,8 @@ TEST_CASE(conv_dynamic_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 6, 0}, {3, 3, 0}, {32, 32, 0}, {32, 32, 0}}});
auto x0 =
mm->add_parameter("0", {migraphx::shape::float_type, {{1, 6}, {3, 3}, {32, 32}, {32, 32}}});
auto x1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto x2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
auto x3 = mm->add_instruction(migraphx::make_op("convolution"), x0, x1);
......@@ -1151,7 +1146,7 @@ TEST_CASE(conv_dynamic_bias_test)
mm->add_return({x5});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 6, 0};
options.default_dyn_dim_value = {1, 6};
auto prog = migraphx::parse_onnx("conv_dynamic_bias_test.onnx", options);
EXPECT(p == prog);
}
......@@ -1160,8 +1155,8 @@ TEST_CASE(conv_dynamic_img_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 10, 0}, {5, 10, 0}}});
auto l0 =
mm->add_parameter("0", {migraphx::shape::float_type, {{1, 1}, {3, 3}, {5, 10}, {5, 10}}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
......@@ -1171,7 +1166,7 @@ TEST_CASE(conv_dynamic_img_test)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 0};
options.default_dyn_dim_value = {5, 10};
auto prog = migraphx::parse_onnx("conv_dynamic_img_test.onnx", options);
EXPECT(p == prog);
......@@ -1182,8 +1177,8 @@ TEST_CASE(conv_dynamic_weights_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l1 = mm->add_parameter(
"1", {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}});
auto l1 =
mm->add_parameter("1", {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
......@@ -1192,7 +1187,7 @@ TEST_CASE(conv_dynamic_weights_test)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 4, 0};
options.default_dyn_dim_value = {2, 4};
auto prog = migraphx::parse_onnx("conv_dynamic_weights_test.onnx", options);
EXPECT(p == prog);
......@@ -1202,10 +1197,10 @@ TEST_CASE(conv_dynamic_img_and_weights_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 10, 0}, {5, 10, 0}}});
auto l1 = mm->add_parameter(
"1", {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}});
auto l0 =
mm->add_parameter("0", {migraphx::shape::float_type, {{1, 1}, {3, 3}, {5, 10}, {5, 10}}});
auto l1 =
mm->add_parameter("1", {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
......@@ -1214,8 +1209,8 @@ TEST_CASE(conv_dynamic_img_and_weights_test)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 0};
options.map_dyn_input_dims["1"] = {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}};
options.default_dyn_dim_value = {5, 10};
options.map_dyn_input_dims["1"] = {{1, 1}, {3, 3}, {2, 4}, {2, 4}};
auto prog = migraphx::parse_onnx("conv_dynamic_img_and_weights_test.onnx", options);
EXPECT(p == prog);
......@@ -1225,8 +1220,8 @@ TEST_CASE(conv_dynamic_batch_same_upper)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 10, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}});
auto l0 =
mm->add_parameter("0", {migraphx::shape::float_type, {{1, 10}, {3, 3}, {5, 5}, {5, 5}}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
......@@ -1236,7 +1231,7 @@ TEST_CASE(conv_dynamic_batch_same_upper)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
options.default_dyn_dim_value = {1, 10};
auto prog = migraphx::parse_onnx("conv_dynamic_batch_same_upper_test.onnx", options);
EXPECT(p == prog);
......@@ -1246,8 +1241,8 @@ TEST_CASE(conv_dynamic_img_same_upper)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 10, 0}, {5, 10, 0}}});
auto l0 =
mm->add_parameter("0", {migraphx::shape::float_type, {{1, 1}, {3, 3}, {5, 10}, {5, 10}}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
......@@ -1260,7 +1255,7 @@ TEST_CASE(conv_dynamic_img_same_upper)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 0};
options.default_dyn_dim_value = {5, 10};
auto prog = migraphx::parse_onnx("conv_dynamic_img_same_upper_test.onnx", options);
EXPECT(p == prog);
......@@ -1271,8 +1266,8 @@ TEST_CASE(conv_dynamic_kernel_same_lower)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5}});
auto l1 = mm->add_parameter(
"1", {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}});
auto l1 =
mm->add_parameter("1", {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}});
auto c0 = mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}},
......@@ -1284,7 +1279,7 @@ TEST_CASE(conv_dynamic_kernel_same_lower)
mm->add_return({c0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 4, 0};
options.default_dyn_dim_value = {2, 4};
auto prog = migraphx::parse_onnx("conv_dynamic_kernel_same_lower_test.onnx", options);
EXPECT(p == prog);
}
......@@ -2030,14 +2025,13 @@ TEST_CASE(flatten_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto ret = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), c0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("flatten_dyn_test.onnx", options);
EXPECT(p == prog);
}
......@@ -2087,11 +2081,9 @@ TEST_CASE(gather_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"data",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
"data", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {5, 5}, {6, 6}}});
auto l1 = mm->add_parameter(
"indices",
migraphx::shape{migraphx::shape::int32_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
"indices", migraphx::shape{migraphx::shape::int32_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}});
auto cont_l0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto cont_l1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
......@@ -2101,7 +2093,7 @@ TEST_CASE(gather_dyn_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("gather_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -2181,15 +2173,15 @@ TEST_CASE(gathernd_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data",
migraphx::shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4}}});
auto l0 = mm->add_parameter(
"data", migraphx::shape{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4}}});
auto l1 = mm->add_parameter("indices",
migraphx::shape{migraphx::shape::int64_type, {{1, 3}, {2, 2}}});
auto r = mm->add_instruction(migraphx::make_op("gathernd"), l0, l1);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["data"] = {{2, 4, 2}, {2, 4}};
options.map_dyn_input_dims["data"] = {{2, 4, {2}}, {2, 4}};
options.map_dyn_input_dims["indices"] = {{1, 3}, {2, 2}};
auto prog = migraphx::parse_onnx("gathernd_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -2318,9 +2310,9 @@ TEST_CASE(gemm_dyn_inner_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {6, 6, 0}}});
"A", migraphx::shape{migraphx::shape::float_type, {{1, 10, {8}}, {6, 6}}});
auto l1 = mm->add_parameter(
"B", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {7, 7, 0}}});
"B", migraphx::shape{migraphx::shape::float_type, {{1, 10, {8}}, {7, 7}}});
auto alpha = 0.5f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
......@@ -2329,7 +2321,7 @@ TEST_CASE(gemm_dyn_inner_test)
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 8};
options.default_dyn_dim_value = {1, 10, {8}};
auto prog = migraphx::parse_onnx("gemm_dyn_inner_test.onnx", options);
EXPECT(p == prog);
}
......@@ -2339,7 +2331,7 @@ TEST_CASE(gemm_dyn_outer_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5}, {5, 10, {7}}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto alpha = 2.f;
auto a_l = mm->add_literal(alpha);
......@@ -2350,7 +2342,7 @@ TEST_CASE(gemm_dyn_outer_test)
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7};
options.default_dyn_dim_value = {5, 10, {7}};
auto prog = migraphx::parse_onnx("gemm_dyn_outer_test.onnx", options);
EXPECT(p == prog);
}
......@@ -2401,10 +2393,8 @@ TEST_CASE(globalavgpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto input = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {16, 16}, {16, 16}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {16, 16}},
......@@ -2413,7 +2403,7 @@ TEST_CASE(globalavgpool_dyn_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("globalavgpool_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -2440,10 +2430,8 @@ TEST_CASE(globallppool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {16, 32, 0}, {16, 32, 0}}});
auto input = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {16, 32}, {16, 32}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm},
{"dyn_global", true},
......@@ -2453,7 +2441,7 @@ TEST_CASE(globallppool_dyn_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {16, 32, 0};
options.default_dyn_dim_value = {16, 32};
auto prog = migraphx::parse_onnx("globallppool_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -2480,10 +2468,8 @@ TEST_CASE(globalmaxpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {32, 32, 0}, {32, 32, 0}}});
auto input = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {32, 32}, {32, 32}}});
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"lengths", {32, 32}},
......@@ -2492,7 +2478,7 @@ TEST_CASE(globalmaxpool_dyn_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("globalmaxpool_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -3691,16 +3677,16 @@ TEST_CASE(matmul_dyn_mm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
auto l1 = mm->add_parameter(
"2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {1, 5, 3}}});
auto l0 =
mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {{4, 8, {6}}, {7, 7}}});
auto l1 =
mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {{7, 7}, {1, 5, {3}}}});
auto ret = migraphx::add_apply_alpha_beta(*mm, {l0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
options.map_dyn_input_dims["2"] = {{7, 7, 0}, {1, 5, 3}};
options.map_dyn_input_dims["1"] = {{4, 8, {6}}, {7, 7}};
options.map_dyn_input_dims["2"] = {{7, 7}, {1, 5, {3}}};
auto prog = parse_onnx("matmul_dyn_mm_test.onnx", options);
EXPECT(p == prog);
......@@ -3710,8 +3696,8 @@ TEST_CASE(matmul_dyn_mv_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
auto l0 =
mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {{4, 8, {6}}, {7, 7}}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = migraphx::add_apply_alpha_beta(*mm, {l0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
......@@ -3719,7 +3705,7 @@ TEST_CASE(matmul_dyn_mv_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
options.map_dyn_input_dims["1"] = {{4, 8, {6}}, {7, 7}};
auto prog = parse_onnx("matmul_dyn_mv_test.onnx", options);
EXPECT(p == prog);
......@@ -3731,14 +3717,14 @@ TEST_CASE(matmul_dyn_vm_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter(
"2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {4, 10, 8}}});
"2", migraphx::shape{migraphx::shape::float_type, {{7, 7}, {4, 10, {8}}}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto res = migraphx::add_apply_alpha_beta(*mm, {sl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["2"] = {{7, 7, 0}, {4, 10, 8}};
options.map_dyn_input_dims["2"] = {{7, 7}, {4, 10, {8}}};
auto prog = parse_onnx("matmul_dyn_vm_test.onnx", options);
EXPECT(p == prog);
......@@ -3748,7 +3734,7 @@ TEST_CASE(matmul_dyn_vv_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{5, 8, 7};
migraphx::shape::dynamic_dimension dd{5, 8, {7}};
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {dd}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {dd}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
......@@ -3769,7 +3755,7 @@ TEST_CASE(matmul_dyn_vv_test)
TEST_CASE(matmul_dyn_broadcast_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws([&] { migraphx::parse_onnx("matmul_dyn_broadcast_error.onnx", options); }));
}
......@@ -3789,7 +3775,7 @@ TEST_CASE(matmulinteger_test)
TEST_CASE(matmulinteger_dyn_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws([&] { migraphx::parse_onnx("matmulinteger_dyn_error.onnx", options); }));
}
......@@ -4098,13 +4084,13 @@ TEST_CASE(neg_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {{1, 10, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::int64_type, {{1, 10}, {3, 3}}};
auto input = mm->add_parameter("0", s);
auto ret = mm->add_instruction(migraphx::make_op("neg"), input);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
options.default_dyn_dim_value = {1, 10};
auto prog = migraphx::parse_onnx("neg_dynamic_test.onnx", options);
EXPECT(p == prog);
}
......@@ -4140,9 +4126,9 @@ TEST_CASE(nms_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {{1, 10, 0}, {6, 6, 0}, {4, 4, 0}}};
migraphx::shape sb{migraphx::shape::float_type, {{1, 10}, {6, 6}, {4, 4}}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 10, 0}, {1, 1, 0}, {6, 6, 0}}};
migraphx::shape ss{migraphx::shape::float_type, {{1, 10}, {1, 1}, {6, 6}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
......@@ -4161,7 +4147,7 @@ TEST_CASE(nms_dynamic_batch_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
options.default_dyn_dim_value = {1, 10};
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_dynamic_batch_test.onnx", options);
......@@ -4172,9 +4158,9 @@ TEST_CASE(nms_dynamic_boxes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {{1, 1, 0}, {6, 20, 0}, {4, 4, 0}}};
migraphx::shape sb{migraphx::shape::float_type, {{1, 1}, {6, 20}, {4, 4}}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {6, 20, 0}}};
migraphx::shape ss{migraphx::shape::float_type, {{1, 1}, {1, 1}, {6, 20}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
......@@ -4187,7 +4173,7 @@ TEST_CASE(nms_dynamic_boxes_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {6, 20, 0};
options.default_dyn_dim_value = {6, 20};
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_dynamic_boxes_test.onnx", options);
......@@ -4200,7 +4186,7 @@ TEST_CASE(nms_dynamic_classes_test)
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {1, 6, 4}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 1, 0}, {1, 10, 0}, {6, 6, 0}}};
migraphx::shape ss{migraphx::shape::float_type, {{1, 1}, {1, 10}, {6, 6}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
......@@ -4213,7 +4199,7 @@ TEST_CASE(nms_dynamic_classes_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
options.default_dyn_dim_value = {1, 10};
options.use_dyn_output = true;
auto prog = migraphx::parse_onnx("nms_dynamic_classes_test.onnx", options);
......@@ -4388,12 +4374,12 @@ TEST_CASE(pad_attr_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4, 2}}});
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}});
auto ret = mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 4, 2}, {2, 4, 2}};
options.map_dyn_input_dims["0"] = {{2, 4, {2}}, {2, 4, {2}}};
auto prog = parse_onnx("pad_attr_dyn_test.onnx", options);
EXPECT(p == prog);
}
......@@ -4403,13 +4389,13 @@ TEST_CASE(pad_cnst_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4, 2}}});
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
auto ret = mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 2, 0, 1}}}), x);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 4, 2}, {2, 4, 2}};
options.map_dyn_input_dims["0"] = {{2, 4, {2}}, {2, 4, {2}}};
auto prog = parse_onnx("pad_cnst_dyn_test.onnx", options);
EXPECT(p == prog);
}
......@@ -4417,7 +4403,7 @@ TEST_CASE(pad_cnst_dyn_test)
TEST_CASE(pad_dyn_reflect_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 4, 2};
options.default_dyn_dim_value = {2, 4, {2}};
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_dyn_reflect_error.onnx", options); }));
}
......@@ -4881,7 +4867,7 @@ TEST_CASE(reducel1_dyn_test)
// a shape with 4 dynamic dimensions
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3, 0}, {3, 5, 0}, {4, 6, 5}, {5, 7, 6}}});
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins);
......@@ -4889,7 +4875,7 @@ TEST_CASE(reducel1_dyn_test)
mm->add_return({sq_ins});
migraphx::onnx_options options;
options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, 5}, {5, 7, 6}};
options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}};
auto prog = migraphx::parse_onnx("reducel1_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -4901,7 +4887,7 @@ TEST_CASE(reducel1_dyn_test)
// No axes given in the onnx file. Parser should default to all axes.
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3, 0}, {3, 5, 0}, {4, 6, 5}, {5, 7, 6}}});
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins);
......@@ -4910,7 +4896,7 @@ TEST_CASE(reducel1_dyn_test)
mm->add_return({sq_ins});
migraphx::onnx_options options;
options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, 5}, {5, 7, 6}};
options.map_dyn_input_dims["x"] = {{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}};
auto prog = migraphx::parse_onnx("reducel1_dyn_noaxes_test.onnx", options);
EXPECT(p == prog);
......@@ -5116,8 +5102,10 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
mm->add_instruction(op, l0);
mm->add_instruction(op, l0);
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c1);
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
......@@ -5806,17 +5794,17 @@ TEST_CASE(scatternd_dyn_test)
auto* mm = p.get_main_module();
// parameters with dynamic dimensions
auto l0 = mm->add_parameter(
"data", migraphx::shape{migraphx::shape::float_type, {{1, 3, 2}, {2, 2}, {2, 2}}});
"data", migraphx::shape{migraphx::shape::float_type, {{1, 3, {2}}, {2, 2}, {2, 2}}});
auto l1 = mm->add_parameter(
"indices", migraphx::shape{migraphx::shape::int64_type, {{2, 1, 2}, {1, 1}, {2, 2}}});
"indices", migraphx::shape{migraphx::shape::int64_type, {{2, 1, {2}}, {1, 1}, {2, 2}}});
auto l2 = mm->add_parameter(
"updates", migraphx::shape{migraphx::shape::float_type, {{2, 1, 2}, {1, 1}, {2, 2}}});
"updates", migraphx::shape{migraphx::shape::float_type, {{2, 1, {2}}, {1, 1}, {2, 2}}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["data"] = {{1, 3, 2}, {2, 2}, {2, 2}};
options.map_dyn_input_dims["indices"] = {{2, 1, 2}, {1, 1}, {2, 2}};
options.map_dyn_input_dims["updates"] = {{2, 1, 2}, {1, 1}, {2, 2}};
options.map_dyn_input_dims["data"] = {{1, 3, {2}}, {2, 2}, {2, 2}};
options.map_dyn_input_dims["indices"] = {{2, 1, {2}}, {1, 1}, {2, 2}};
options.map_dyn_input_dims["updates"] = {{2, 1, {2}}, {1, 1}, {2, 2}};
auto prog = migraphx::parse_onnx("scatternd_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -5950,7 +5938,7 @@ TEST_CASE(sinh_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{1, 10, 0};
migraphx::shape::dynamic_dimension dd{1, 10};
std::vector<migraphx::shape::dynamic_dimension> dyn_dims;
dyn_dims.push_back(dd);
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dyn_dims});
......@@ -6016,7 +6004,7 @@ TEST_CASE(slice_dyn_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {1, 3, 0}, {2, 2, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{3, 3}, {1, 3}, {2, 2}}});
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l0);
mm->add_return({ret});
......@@ -6024,7 +6012,7 @@ TEST_CASE(slice_dyn_test)
migraphx::onnx_options options;
// Parser converts the dynamic input shape to static unless there is at least one non-fixed
// dynamic dimension. Slicing is not allowed along the non-fixed axis 1.
options.map_dyn_input_dims["0"] = {{3, 3, 0}, {1, 3, 0}, {2, 2, 0}};
options.map_dyn_input_dims["0"] = {{3, 3}, {1, 3}, {2, 2}};
auto prog = migraphx::parse_onnx("slice_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -6035,7 +6023,7 @@ TEST_CASE(slice_step_dyn_test)
// A slice command with non-default steps will have a "Step" instruction added in parsing.
// At the time of writing, Step doesn't support dynamic shape input.
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_step_dyn_test.onnx", options); }));
}
......@@ -6044,7 +6032,7 @@ TEST_CASE(slice_reverse_dyn_test)
// A slice command with negative step on any axis will have a "Reverse" instruction added in
// parsing. At the time of writing, Reverse doesn't support dynamic shape input.
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_reverse_dyn_test.onnx", options); }));
}
......@@ -6173,13 +6161,12 @@ TEST_CASE(softmax_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {4, 4}}});
auto ret = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), l0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("softmax_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -6389,10 +6376,9 @@ TEST_CASE(squeeze_unsqueeze_dyn_test)
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter(
"0",
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 4, 0}, {1, 1, 0}, {1, 1, 0}, {1, 4, 0}, {1, 1, 0}}});
{{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
......@@ -6400,7 +6386,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test)
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("squeeze_unsqueeze_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -6694,14 +6680,13 @@ TEST_CASE(transpose_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}});
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}, {3, 3}}});
std::vector<int64_t> perm{0, 3, 1, 2};
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({t0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -6891,7 +6876,7 @@ TEST_CASE(variable_batch_user_input_test1)
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 2, 0};
options.default_dyn_dim_value = {2, 2};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
......@@ -6902,14 +6887,13 @@ TEST_CASE(variable_batch_user_input_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 5}, {3, 3}, {16, 16}, {16, 16}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 5, 0};
options.default_dyn_dim_value = {2, 5};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
......@@ -6920,14 +6904,13 @@ TEST_CASE(variable_batch_user_input_test3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 5}, {3, 3}, {16, 16}, {16, 16}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
options.map_dyn_input_dims["0"] = {{2, 5}, {3, 3}, {16, 16}, {16, 16}};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
......@@ -6955,7 +6938,7 @@ TEST_CASE(variable_batch_user_input_test5)
// Error using default_dim_value and default_dyn_dim_value
migraphx::onnx_options options;
options.default_dim_value = 2;
options.default_dyn_dim_value = {1, 2, 0};
options.default_dyn_dim_value = {1, 2};
EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
}
......@@ -6964,7 +6947,7 @@ TEST_CASE(variable_batch_user_input_test6)
{
// Error using both map_dyn_input_dims and map_input_dims
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
options.map_dyn_input_dims["0"] = {{2, 5}, {3, 3}, {16, 16}, {16, 16}};
options.map_input_dims["0"] = {2, 3, 16, 16};
EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
......@@ -7012,17 +6995,17 @@ TEST_CASE(where_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto lc = mm->add_parameter(
"c", migraphx::shape{migraphx::shape::bool_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
"c", migraphx::shape{migraphx::shape::bool_type, {{1, 4}, {2, 2}, {2, 2}}});
auto lx = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
"x", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}});
auto ly = mm->add_parameter(
"y", migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}}});
"y", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}});
auto r = mm->add_instruction(migraphx::make_op("where"), lc, lx, ly);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("where_dyn_test.onnx", options);
EXPECT(p == prog);
......@@ -7032,7 +7015,7 @@ TEST_CASE(where_mixed_test)
{
// mixture of static and dynamic input shapes is not supported
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
options.default_dyn_dim_value = {1, 4};
EXPECT(test::throws([&] { migraphx::parse_onnx("where_mixed_test.onnx", options); }));
}
......
......@@ -121,20 +121,16 @@ TEST_CASE(argmax_axis_outofbounds)
TEST_CASE(argmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {1, 1, 0}, {4, 4, 0}, {5, 5, 0}}},
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {{1, 4}, {1, 1}, {4, 4}, {5, 5}}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
TEST_CASE(argmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {3, 3, 0}, {1, 1, 0}, {4, 6, 0}}},
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 6}, {4, 6}}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {{1, 4}, {3, 3}, {1, 1}, {4, 6}}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
......@@ -142,7 +138,7 @@ TEST_CASE(argmax_dyn1)
TEST_CASE(binary_dyn_static_error)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {4, 4, 4}, {4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1}, {4, 4, {4}}, {4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("add"), a_shape, b_shape);
}
......@@ -216,13 +212,13 @@ TEST_CASE(broadcast_2in_not_matching_error)
TEST_CASE(broadcast_2in_dynamic_s0_error1)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
}
TEST_CASE(broadcast_2in_dynamic_s0_error2)
{
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4}};
migraphx::shape a_input{migraphx::shape::float_type, dd};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
......@@ -231,9 +227,9 @@ TEST_CASE(broadcast_2in_dynamic_s0_error2)
TEST_CASE(broadcast_2in_static_dyn)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
......@@ -243,11 +239,11 @@ TEST_CASE(broadcast_2in_static_dyn)
TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(convolution_shape)
TEST_CASE(conv_2d_0)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
......@@ -257,13 +253,19 @@ TEST_CASE(convolution_shape)
throws_shape(
migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input);
}
TEST_CASE(conv_2d_1)
{
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::make_op("convolution"), input2, weights2);
throws_shape(migraphx::make_op("convolution"), input2, weights);
}
// 1D convolution
TEST_CASE(conv_1d)
{
migraphx::shape output_1d{migraphx::shape::float_type, {4, 4, 1}};
migraphx::shape input_1d{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}};
......@@ -272,12 +274,17 @@ TEST_CASE(convolution_shape)
migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input_1d,
weights_1d);
}
// channel numbers mismatch
weights_1d = {migraphx::shape::float_type, {4, 8, 3}};
TEST_CASE(conv_channel_mismatch)
{
migraphx::shape input_1d{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape weights_1d = {migraphx::shape::float_type, {4, 8, 3}};
throws_shape(migraphx::make_op("convolution"), input_1d, weights_1d);
}
// 3D convolution
TEST_CASE(conv_3D)
{
migraphx::shape output_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}};
migraphx::shape input_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
......@@ -289,93 +296,82 @@ TEST_CASE(convolution_shape)
weights_3d);
throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d);
}
// dynamic batch
TEST_CASE(conv_dyn_batch)
{
migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
{{1, 100}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape{migraphx::shape::float_type,
{{
1,
100,
0,
},
{1, 1, 0},
{3, 3, 0},
{3, 3, 0}}};
{{1, 100}, {1, 1}, {3, 3}, {3, 3}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// dynamic image
input_dyn_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 20, 0}, {5, 20, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
output_dyn_shape = {migraphx::shape::float_type,
{{
1,
1,
0,
},
{1, 1, 0},
{3, 18, 0},
{3, 18, 0}}};
TEST_CASE(conv_dyn_img)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {3, 18}, {3, 18}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// dynamic weights
input_dyn_shape = {migraphx::shape::float_type, {1, 3, 10, 10}};
weights_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}};
output_dyn_shape = {migraphx::shape::float_type,
{{
1,
1,
0,
},
{1, 1, 0},
{7, 9, 0},
{7, 9, 0}}};
TEST_CASE(conv_dyn_weights)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type, {1, 3, 10, 10}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {7, 9}, {7, 9}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// dynamic img and weights
input_dyn_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 20, 0}, {5, 20, 0}}};
weights_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}};
output_dyn_shape = {migraphx::shape::float_type,
{{
1,
1,
0,
},
{1, 1, 0},
{2, 19, 0},
{2, 19, 0}}};
TEST_CASE(conv_dyn_img_weights)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {2, 19}, {2, 19}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// input attr shape mismatch
input_dyn_shape = {migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}, {5, 5, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3, 3}};
TEST_CASE(conv_attr_shape_mismatch)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 100}, {3, 3}, {5, 5}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
TEST_CASE(conv_autopad_dyn_batch)
{
// auto_pad dynamic batch
input_dyn_shape = {migraphx::shape::float_type, {{1, 10, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
output_dyn_shape = {migraphx::shape::float_type, {{1, 10, 0}, {1, 1, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {1, 1}, {5, 5}, {5, 5}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"stride", {1, 1}},
......@@ -383,12 +379,16 @@ TEST_CASE(convolution_shape)
{"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
input_dyn_shape,
weights_shape);
}
TEST_CASE(conv_autopad_dyn_img)
{
// auto_pad dynamic img
input_dyn_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 10, 0}, {5, 10, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
output_dyn_shape = {migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {5, 10, 0}, {5, 10, 0}}};
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 10}, {5, 10}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {5, 10}, {5, 10}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"stride", {1, 1}},
......@@ -396,13 +396,15 @@ TEST_CASE(convolution_shape)
{"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
input_dyn_shape,
weights_shape);
}
// auto_pad dynamic kernel
input_dyn_shape = {migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {10, 10, 0}, {10, 10, 0}}};
weights_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}};
output_dyn_shape = {migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {10, 10, 0}, {10, 10, 0}}};
TEST_CASE(conv_autopad_dyn_kernel)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {10, 10}, {10, 10}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {10, 10}, {10, 10}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"stride", {1, 1}},
......@@ -425,17 +427,24 @@ TEST_CASE(contiguous_shape)
TEST_CASE(contiguous_dyn_shape)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 2}}};
migraphx::shape s0{migraphx::shape::float_type, {{1, 4}, {2, 2, {2}}}};
expect_shape(s0, migraphx::make_op("contiguous"), s0);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
migraphx::shape output{migraphx::shape::float_type, {1}};
migraphx::shape input{migraphx::shape::float_type};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(contiguous_shape_singleton_dim)
{
migraphx::shape output{migraphx::shape::float_type, {5, 1, 8}, {8, 8, 1}};
migraphx::shape input{migraphx::shape::float_type, {5, 1, 8}, {8, 4, 1}};
expect_shape(output, migraphx::make_op("contiguous"), input);
}
TEST_CASE(deconvolution_shape)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
......@@ -611,9 +620,9 @@ TEST_CASE(dot_4D_test)
TEST_CASE(dot_dyn_static_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {8, 8, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
......@@ -621,16 +630,16 @@ TEST_CASE(dot_dyn_static_test0)
TEST_CASE(dot_dyn_static_mismatch_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5, 0}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5}, {6, 8, {8}}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {6, 8, {8}}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
......@@ -638,9 +647,9 @@ TEST_CASE(dot_dyn_dyn_test0)
TEST_CASE(dot_dyn_dyn_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {4, 5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, 5}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {4, 5, {5}}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, {5}}, {6, 8, {8}}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {6, 8, {8}}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
......@@ -648,14 +657,14 @@ TEST_CASE(dot_dyn_dyn_test1)
TEST_CASE(dot_dyn_mismatch_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_mismatch_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4, 0}, {5, 5, 0}, {2, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4}, {5, 5}, {2, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
......@@ -690,12 +699,11 @@ TEST_CASE(flatten_shape)
TEST_CASE(flatten_dyn_axis0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {192, 768, 0}}},
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {6, 6}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {192, 768}}},
migraphx::make_op("flatten", {{"axis", 0}}),
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {192, 768, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {192, 768}}},
migraphx::make_op("flatten", {{"axis", -4}}),
input);
}
......@@ -703,13 +711,13 @@ TEST_CASE(flatten_dyn_axis0)
TEST_CASE(flatten_dyn_axis1)
{
migraphx::shape input{migraphx::shape::float_type,
{{2, 2, 2}, {4, 4, 0}, {4, 6, 5}, {4, 6, 5}}};
{{2, 2, {2}}, {4, 4}, {4, 6, {5}}, {4, 6, {5}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 2, 2}, {4 * 4 * 4, 4 * 6 * 6, 0}}},
migraphx::shape{migraphx::shape::float_type, {{2, 2, {2}}, {4 * 4 * 4, 4 * 6 * 6}}},
migraphx::make_op("flatten", {{"axis", 1}}),
input);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 2, 2}, {4 * 4 * 4, 4 * 6 * 6, 0}}},
migraphx::shape{migraphx::shape::float_type, {{2, 2, {2}}, {4 * 4 * 4, 4 * 6 * 6}}},
migraphx::make_op("flatten", {{"axis", -3}}),
input);
}
......@@ -717,29 +725,25 @@ TEST_CASE(flatten_dyn_axis1)
TEST_CASE(flatten_dyn_axis2)
{
migraphx::shape input{migraphx::shape::float_type,
{{2, 2, 2}, {4, 4, 0}, {4, 6, 5}, {4, 6, 5}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2 * 4, 2 * 4, 0}, {4 * 4, 6 * 6, 5 * 5}}},
{{2, 2, {2}}, {4, 4}, {4, 6, {5}}, {4, 6, {5}}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2 * 4, 2 * 4}, {4 * 4, 6 * 6}}},
migraphx::make_op("flatten", {{"axis", 2}}),
input);
}
TEST_CASE(flatten_dyn_axis3)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6, 4 * 4 * 6, 0}, {8, 8, 0}}},
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {6, 6}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6, 4 * 4 * 6}, {8, 8}}},
migraphx::make_op("flatten", {{"axis", 3}}),
input);
}
TEST_CASE(flatten_dyn_axis4)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1 * 4 * 6 * 8, 4 * 4 * 6 * 8, 0}, {1, 1, 0}}},
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {6, 6}, {8, 8}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6 * 8, 4 * 4 * 6 * 8}, {1, 1}}},
migraphx::make_op("flatten", {{"axis", 4}}),
input);
}
......@@ -835,11 +839,11 @@ TEST_CASE(gather_dyn0)
{
// Insert dynamic index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 7, 3}, {3, 3, 0}}};
{{2, 3, {2}}, {3, 4, {3}}, {6, 9, {7}}, {12, 14, {13}}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 7, {3}}, {3, 3}}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {2, 7, 3}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
{{2, 3, {2}}, {2, 7, {3}}, {3, 3}, {6, 9, {7}}, {12, 14, {13}}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -849,11 +853,11 @@ TEST_CASE(gather_dyn1)
{
// Insert static index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
{{2, 3, {2}}, {3, 4, {3}}, {6, 9, {7}}, {12, 14, {13}}}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {2, 2, 0}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
{{2, 3, {2}}, {2, 2}, {3, 3}, {6, 9, {7}}, {12, 14, {13}}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -863,14 +867,15 @@ TEST_CASE(gather_dyn2)
{
// Insert scalar (static) index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
{{2, 3, {2}}, {3, 4, {3}}, {6, 9, {7}}, {12, 14, {13}}}};
std::vector<std::size_t> mins;
std::vector<std::size_t> maxes;
std::vector<std::size_t> opts;
std::vector<std::set<std::size_t>> opts;
migraphx::shape indices{migraphx::shape::int32_type, mins, maxes, opts};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2, 3, 2}, {6, 9, 7}, {12, 14, 13}}},
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 3, {2}}, {6, 9, {7}}, {12, 14, {13}}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -880,10 +885,10 @@ TEST_CASE(gather_dyn3)
{
// Insert dynamic index into static shape, axis 1
migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, {2}}, {3, 4, {3}}}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 2, 0}, {2, 3, 2}, {3, 4, 3}, {6, 6, 0}, {12, 12, 0}}},
{{2, 2}, {2, 3, {2}}, {3, 4, {3}}, {6, 6}, {12, 12}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -893,10 +898,10 @@ TEST_CASE(gather_dyn4)
{
// Insert dynamic index into static shape, axis 0
migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, {2}}, {3, 4, {3}}}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {3, 3, 0}, {6, 6, 0}, {12, 12, 0}}},
{{2, 3, {2}}, {3, 4, {3}}, {3, 3}, {6, 6}, {12, 12}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -1432,13 +1437,13 @@ TEST_CASE(multibroadcast)
TEST_CASE(multibroadcast_2in_static_dyn0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4}, {4, 4, {4}}, {4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
......@@ -1447,13 +1452,13 @@ TEST_CASE(multibroadcast_2in_static_dyn0)
TEST_CASE(multibroadcast_2in_static_dyn1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
......@@ -1462,13 +1467,13 @@ TEST_CASE(multibroadcast_2in_static_dyn1)
TEST_CASE(multibroadcast_2in_static_dyn2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
b_shape,
a_shape);
......@@ -1478,7 +1483,7 @@ TEST_CASE(multibroadcast_2in_static_dyn_error0)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1488,7 +1493,7 @@ TEST_CASE(multibroadcast_2in_static_dyn_error1)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1498,7 +1503,7 @@ TEST_CASE(multibroadcast_2in_static_dyn_error2)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1506,15 +1511,15 @@ TEST_CASE(multibroadcast_2in_static_dyn_error2)
TEST_CASE(multibroadcast_2in_dyn_dyn0)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, {2}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
......@@ -1522,15 +1527,15 @@ TEST_CASE(multibroadcast_2in_dyn_dyn0)
TEST_CASE(multibroadcast_2in_dyn_dyn1)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, {2}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
b_shape,
a_shape);
......@@ -1539,9 +1544,9 @@ TEST_CASE(multibroadcast_2in_dyn_dyn1)
TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
{
// max doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, {2}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1550,9 +1555,9 @@ TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
TEST_CASE(multibroadcast_2in_dyn_dyn_error1)
{
// opt doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, {3}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1663,7 +1668,7 @@ TEST_CASE(nms_shape)
score_thres_s);
// use_dyn_output == true
output_s = {migraphx::shape::int64_type, {{0, 6, 0}, {3, 3, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 6}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1674,9 +1679,9 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic batches
boxes_s = {migraphx::shape::float_type, {{1, 3, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 3, 0}, {1, 1, 0}, {6, 6, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 18, 0}, {3, 3, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 3}, {6, 6}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 3}, {1, 1}, {6, 6}}};
output_s = {migraphx::shape::int64_type, {{0, 18}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1687,9 +1692,9 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic num boxes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 20, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {6, 20, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 20, 0}, {3, 3, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 1}, {6, 20}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 1}, {6, 20}}};
output_s = {migraphx::shape::int64_type, {{0, 20}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1709,9 +1714,9 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic classes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {6, 6, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 6, 0}, {3, 3, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 1}, {6, 6}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 3}, {6, 6}}};
output_s = {migraphx::shape::int64_type, {{0, 6}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1744,8 +1749,8 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic mismatch batches
boxes_s = {migraphx::shape::float_type, {{1, 4, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{2, 8, 0}, {1, 1, 0}, {6, 6, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{2, 8}, {1, 1}, {6, 6}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1755,8 +1760,8 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic mismatch num boxes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 8, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {3, 9, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 1}, {6, 8}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 1}, {3, 9}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1767,7 +1772,7 @@ TEST_CASE(nms_shape)
// dynamic number of classes, fixed boxes_s, mismatch batches
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {{1, 3, 0}, {1, 3, 0}, {6, 6, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 3}, {1, 3}, {6, 6}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1777,7 +1782,7 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic number of classes, fixed boxes_s, mismatch num boxes
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {4, 8, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 3}, {4, 8}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1803,19 +1808,17 @@ TEST_CASE(pad_shape1)
TEST_CASE(pad_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {3, 5, 0}, {3, 5, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {5, 7, 0}, {5, 7, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4, {2}}, {3, 3}, {3, 5}, {3, 5}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4, {2}}, {3, 3}, {5, 7}, {5, 7}}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pad_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {3, 5, 5}, {3, 5, 5}}};
{{1, 4, {2}}, {3, 3}, {3, 5, {5}}, {3, 5, {5}}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {5, 7, 7}, {5, 7, 7}}};
{{1, 4, {2}}, {3, 3}, {5, 7, {7}}, {5, 7, {7}}}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
......@@ -1873,8 +1876,7 @@ TEST_CASE(pooling_shape3)
TEST_CASE(pooling_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}};
throws_shape(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
......@@ -1885,10 +1887,8 @@ TEST_CASE(pooling_dyn_shape0)
TEST_CASE(pooling_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {1, 1, 1}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {3, 3}, {1, 1}, {1, 1}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -1900,10 +1900,8 @@ TEST_CASE(pooling_dyn_shape1)
TEST_CASE(pooling_dyn_shape2)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {5, 5, 0}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 0}, {5, 5, 0}, {2, 2, 2}, {2, 2, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {5, 5}, {3, 3, {3}}, {3, 3}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {5, 5}, {2, 2}, {2, 2}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -1917,9 +1915,8 @@ TEST_CASE(pooling_dyn_shape2)
TEST_CASE(pooling_dyn_shape3)
{
migraphx::shape input{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {4, 12, 8}, {4, 12, 8}}};
migraphx::shape output{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {2, 4, 3}, {2, 4, 3}}};
{{4, 4}, {3, 3}, {4, 12, {8}}, {4, 12, {8}}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4}, {3, 3}, {2, 4}, {2, 4}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -1932,9 +1929,8 @@ TEST_CASE(pooling_dyn_shape3)
TEST_CASE(pooling_dyn_shape4)
{
migraphx::shape input{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {4, 12, 8}, {4, 12, 8}}};
migraphx::shape output{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {3, 6, 4}, {3, 6, 4}}};
{{4, 4}, {3, 3}, {4, 12, {8}}, {4, 12, {8}}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4}, {3, 3}, {3, 6}, {3, 6}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -2018,6 +2014,62 @@ TEST_CASE(quant_dot_2args)
}
}
TEST_CASE(qlinear)
{
migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
migraphx::shape input{migraphx::shape::float_type, {2, 4}};
migraphx::shape result{migraphx::shape::uint8_type, {2, 4}};
expect_shape(result, migraphx::make_op("quantizelinear"), input, scales);
}
TEST_CASE(qlinear_zeros)
{
migraphx::shape zeros{migraphx::shape::int8_type, {2, 4}};
migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
migraphx::shape input{migraphx::shape::float_type, {2, 4}};
migraphx::shape result{migraphx::shape::int8_type, {2, 4}};
expect_shape(result, migraphx::make_op("quantizelinear"), input, scales, zeros);
}
TEST_CASE(qlinear_fp16)
{
migraphx::shape scales{migraphx::shape::half_type, {2, 4}};
migraphx::shape input{migraphx::shape::half_type, {2, 4}};
migraphx::shape result{migraphx::shape::uint8_type, {2, 4}};
expect_shape(result, migraphx::make_op("quantizelinear"), input, scales);
}
TEST_CASE(qlinear_mismatch_type)
{
migraphx::shape scales{migraphx::shape::int8_type, {2, 4}};
migraphx::shape input{migraphx::shape::float_type, {2, 4}};
throws_shape(migraphx::make_op("quantizelinear"), input, scales);
}
TEST_CASE(dqlinear)
{
migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
migraphx::shape result{migraphx::shape::float_type, {2, 4}};
expect_shape(result, migraphx::make_op("dequantizelinear"), input, scales);
}
TEST_CASE(dqlinear_fp16)
{
migraphx::shape scales{migraphx::shape::half_type, {2, 4}};
migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
migraphx::shape result{migraphx::shape::half_type, {2, 4}};
expect_shape(result, migraphx::make_op("dequantizelinear"), input, scales);
}
TEST_CASE(dqlinear_mismatch_type)
{
migraphx::shape zeros{migraphx::shape::float_type, {2, 4}};
migraphx::shape scales{migraphx::shape::float_type, {2, 4}};
migraphx::shape input{migraphx::shape::int8_type, {2, 4}};
throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
}
template <class T>
void test_reduce_ops()
{
......@@ -2054,32 +2106,32 @@ template <class T>
void test_dyn_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{2, 3, 3}, {1, 1, 0}})},
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})},
T{{-1}},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{1, 1, 0}, {2, 4, 4}})},
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})},
T{{0}},
input);
}
{
// Empty axis argument reduces all axes
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{1, 1, 0}, {1, 1, 0}})},
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})},
T{{}},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
throws_shape(T{{4}}, input);
}
}
......@@ -2107,7 +2159,7 @@ TEST_CASE(reshape_shape)
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}})
{
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
......@@ -2131,8 +2183,7 @@ TEST_CASE(reshape_shape)
TEST_CASE(reshape_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
{-1, 1, 1, 24}, {0, 8, 3, 1}, {-1, 3, 4, 2}, {0, 2, 4, 3}})
{
......@@ -2146,7 +2197,7 @@ TEST_CASE(reshape_dyn_shape)
else
{
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0});
out_dyn_dims.push_back({d, d});
}
}
migraphx::shape output{migraphx::shape::float_type, out_dyn_dims};
......@@ -2156,24 +2207,21 @@ TEST_CASE(reshape_dyn_shape)
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 20, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 20}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 1, 0, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_fixed_ele_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 10, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 10}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 1, 5, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_non_fixed_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {2, 1, 1, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
......@@ -2393,28 +2441,28 @@ TEST_CASE(slice_shape)
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
// Slice axis 1 to size 4-1=3
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {3, 3, 0}, {2, 3, 0}}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {3, 3}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}),
input);
}
TEST_CASE(slice_dyn_shape1)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
// Slice axis 1 with negative index
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {2, 3, 0}}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {2, 2}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {-4}}}),
input);
}
TEST_CASE(slice_dyn_shape2)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
// Sliced range max bigger than dimension; is clipped
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {6, 6, 0}, {2, 3, 0}}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {6, 6}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {10}}}),
input);
}
......@@ -2423,11 +2471,11 @@ TEST_CASE(slice_dyn_shape3)
{
// TODO: When variable dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error.
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 8, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}};
throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
input);
// clang-format off
// expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {1, 1, 0}, {2, 3, 0}}},
// expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {1, 1}, {2, 3}}},
// migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
// input);
// clang-format on
......@@ -2435,10 +2483,10 @@ TEST_CASE(slice_dyn_shape3)
TEST_CASE(slice_dyn_shape4)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 2, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 2}, {7, 7}, {2, 3}}};
// Slice multiple axes: axis 0 to size 2-1=1 and axis 1 to size 4-1=3
expect_shape(
migraphx::shape{migraphx::shape::int32_type, {{1, 1, 0}, {3, 3, 0}, {2, 3, 0}}},
migraphx::shape{migraphx::shape::int32_type, {{1, 1}, {3, 3}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {2, 4}}}),
input);
}
......@@ -2446,7 +2494,7 @@ TEST_CASE(slice_dyn_shape4)
TEST_CASE(slice_dyn_shape5)
{
// Axis out of range.
migraphx::shape input{migraphx::shape::int32_type, {{2, 2, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 2}, {7, 7}, {2, 3}}};
throws_shape(
migraphx::make_op("slice", {{"axes", {0, 20}}, {"starts", {1, 1}}, {"ends", {2, 4}}}),
input);
......@@ -2456,15 +2504,13 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
TEST_CASE(softmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {5, 8, 6}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 1}, {3, 3}, {4, 6}, {5, 8, {6}}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
......@@ -2629,7 +2675,7 @@ TEST_CASE(test_gathernd_dynamic0)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8}};
migraphx::shape ds{dtype, b};
int batch_dims(1);
......@@ -2642,7 +2688,7 @@ TEST_CASE(test_gathernd_dynamic1)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}};
migraphx::shape ds{dtype, b};
int batch_dims(1);
......@@ -2655,7 +2701,7 @@ TEST_CASE(test_gathernd_dynamic2)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 1}};
migraphx::shape ds{dtype, {{2, 3, 3}, {5, 6, 5}, {6, 9, 7}, {7, 8, 8}}};
migraphx::shape ds{dtype, {{2, 3, {3}}, {5, 6, {5}}, {6, 9, {7}}, {7, 8, {8}}}};
int batch_dims(3);
throws_shape(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2667,10 +2713,10 @@ TEST_CASE(test_gathernd_dynamic3)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {1}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}};
migraphx::shape ds{dtype, b};
migraphx::shape::dynamic_dimension ddout{1, 1, 0};
migraphx::shape::dynamic_dimension ddout{1, 1};
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd"), ds, is);
}
......@@ -2681,10 +2727,10 @@ TEST_CASE(test_gathernd_dynamic4)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 2}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}, {2, 2}};
migraphx::shape ds{dtype, b};
migraphx::shape::dynamic_dimension ddout{2, 2, 0};
migraphx::shape::dynamic_dimension ddout{2, 2};
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd"), ds, is);
}
......@@ -2696,10 +2742,10 @@ TEST_CASE(test_gathernd_dynamic5)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 1}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}, {2, 2, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}, {2, 2}, {2, 2}};
migraphx::shape ds{dtype, b};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2}, {2, 2}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2711,11 +2757,11 @@ TEST_CASE(test_gathernd_dynamic6)
// index dynamic shape, data static
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
std::vector<migraphx::shape::dynamic_dimension> b{{2, 3, 0}, {1, 1, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 3}, {1, 1}};
migraphx::shape is{itype, b};
migraphx::shape ds{dtype, {2, 2, 2}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 3, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 3}, {2, 2}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2726,7 +2772,7 @@ TEST_CASE(test_gathernd_dynamic6a)
// indices with non-fixed dynamic dimension k
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}, {1, 3, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}, {1, 3}};
migraphx::shape is{itype, b};
migraphx::shape ds{dtype, {2, 2, 2}};
......@@ -2740,12 +2786,12 @@ TEST_CASE(test_gathernd_dynamic7)
// index and data both dynamic shapes
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
std::vector<migraphx::shape::dynamic_dimension> idyn{{2, 5, 0}, {1, 1, 0}};
std::vector<migraphx::shape::dynamic_dimension> idyn{{2, 5}, {1, 1}};
migraphx::shape is{itype, idyn};
std::vector<migraphx::shape::dynamic_dimension> bdyn{{1, 2, 0}, {1, 2, 0}, {1, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> bdyn{{1, 2}, {1, 2}, {1, 2}};
migraphx::shape ds{dtype, bdyn};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 5, 0}, {1, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 5}, {1, 2}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2758,10 +2804,10 @@ TEST_CASE(test_gathernd_dynamic8)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 5, 1}};
std::vector<migraphx::shape::dynamic_dimension> b{{6, 7, 7}, {3, 3, 0}, {1, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{6, 7, {7}}, {3, 3}, {1, 4}};
migraphx::shape ds{dtype, b};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2, 0}, {5, 5, 0}, {1, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2}, {5, 5}, {1, 4}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2841,7 +2887,7 @@ TEST_CASE(test_scatternd_dyn0)
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4}};
migraphx::shape is{itype, {4, 13}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2853,7 +2899,7 @@ TEST_CASE(test_scatternd_dyn1)
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
expect_shape(ds, migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2866,7 +2912,7 @@ TEST_CASE(test_scatternd_dyn2)
migraphx::shape ds{dtype, {2, 3, 1, 4}, {0, 1, 1, 0}};
migraphx::shape ds_std{dtype, {2, 3, 1, 4}};
migraphx::shape is{itype, {4, 4}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
expect_shape(ds_std, migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2878,7 +2924,7 @@ TEST_CASE(test_scatternd_dyn3)
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 3, 1, 4}};
migraphx::shape is{itype, {4, 4}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
expect_shape(ds, migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2889,7 +2935,7 @@ TEST_CASE(test_scatternd_dyn4)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 3, 1, 4}};
migraphx::shape::dynamic_dimension dd{4, 5, 0};
migraphx::shape::dynamic_dimension dd{4, 5};
migraphx::shape is{itype, {dd, dd}};
migraphx::shape us{dtype, {dd}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
......@@ -2901,8 +2947,8 @@ TEST_CASE(test_scatternd_dyn5)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 3, 1, 4}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dbad{2, 3, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape::dynamic_dimension dbad{2, 3};
migraphx::shape is{itype, {dd, dd}};
migraphx::shape us{dtype, {dbad}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
......@@ -2924,12 +2970,11 @@ TEST_CASE(test_squeeze_all)
TEST_CASE(test_squeeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {3, 3}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
expect_shape(s3, migraphx::make_op("squeeze"), s1);
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
......@@ -2937,12 +2982,11 @@ TEST_CASE(test_squeeze_dyn)
TEST_CASE(test_squeeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {3, 3}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
expect_shape(s3, migraphx::make_op("squeeze", {{"axes", {-2, -4}}}), s1);
}
......@@ -2989,12 +3033,11 @@ TEST_CASE(test_unsqueeze)
TEST_CASE(test_unsqueeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}, {1, 1}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), s1);
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
......@@ -3002,12 +3045,11 @@ TEST_CASE(test_unsqueeze_dyn)
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}, {1, 1}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
......@@ -3170,16 +3212,16 @@ TEST_CASE(transpose_shape)
TEST_CASE(transpose_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{2, 2, 0}, {1, 4, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {2, 2}}};
migraphx::shape output{migraphx::shape::float_type, {{2, 2}, {1, 4}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
}
TEST_CASE(transpose_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4, 0}, {4, 4, 0}, {1, 4, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4}, {4, 4}, {1, 4}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input);
}
......@@ -3236,8 +3278,8 @@ TEST_CASE(where_broadcast_input)
TEST_CASE(where_dyn_input0)
{
// dynamic shapes not the same
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {2, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3}, {2, 3}}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
......@@ -3246,7 +3288,7 @@ TEST_CASE(where_dyn_input1)
{
// mixed static/dynamic inputs (not allowed)
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {2, 1}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 2}, {2, 2}}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}, {2, 1}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
......@@ -3254,18 +3296,18 @@ TEST_CASE(where_dyn_input1)
TEST_CASE(where_dyn_input2)
{
// dynamic shapes
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3}, {3, 3}}};
expect_shape(s2, migraphx::make_op("where"), s3, s1, s2);
}
TEST_CASE(where_dyn_input3)
{
// dynamic shapes, predicate shape is different
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3, 0}, {3, 4, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3}, {3, 4}}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
......@@ -3318,9 +3360,9 @@ TEST_CASE(test_concat)
TEST_CASE(test_dyn_concat)
{
migraphx::shape sx{migraphx::shape::float_type, {{1, 3, 3}, {4, 4}, {1, 5, 5}, {6, 6}}};
migraphx::shape sy{migraphx::shape::float_type, {{1, 3, 3}, {4, 4}, {1, 4, 4}, {6, 6}}};
migraphx::shape sout{migraphx::shape::float_type, {{1, 3, 3}, {4, 4, 0}, {2, 9, 0}, {6, 6}}};
migraphx::shape sx{migraphx::shape::float_type, {{1, 3, {3}}, {4, 4}, {1, 5, {5}}, {6, 6}}};
migraphx::shape sy{migraphx::shape::float_type, {{1, 3, {3}}, {4, 4}, {1, 4, {4}}, {6, 6}}};
migraphx::shape sout{migraphx::shape::float_type, {{1, 3, {3}}, {4, 4}, {2, 9}, {6, 6}}};
expect_shape(sout, migraphx::make_op("concat", {{"axis", 2}}), sx, sy);
......@@ -3328,7 +3370,7 @@ TEST_CASE(test_dyn_concat)
throws_shape(migraphx::make_op("concat", {{"axis", 4}}), sx, sy);
// rank doesn't match
migraphx::shape srank{migraphx::shape::int64_type, {{1, 3, 3}, {4, 4}}};
migraphx::shape srank{migraphx::shape::int64_type, {{1, 3, {3}}, {4, 4}}};
throws_shape(migraphx::make_op("concat", {{"axis", 0}}), sx, srank);
// non-matching dimension 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment