Unverified commit 9fee7233, authored by Chris Austen, committed via GitHub
Browse files

Merge pull request #2019 from ROCmSoftwarePlatform/rel57_workitems

parents 0bc60894 97cc1dfc
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/gpu/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/tuning_config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -36,16 +37,19 @@ struct module; ...@@ -36,16 +37,19 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m);
MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx, MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx,
module m, module m,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs,
const value& solution);
MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
instruction_ref ins, instruction_ref ins,
code_object_op co, code_object_op co,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs);
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(module m,
const std::vector<shape>& inputs);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
#define MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Result of querying the tuning space for a kernel: a description of the
// problem being tuned together with the candidate configurations that can be
// benchmarked for it.
struct tuning_config
{
// Serialized problem description (the tuning-table key for this module).
value problem;
// Candidate tuning parameter sets; one of these can later be passed back as
// the "solution" argument when compiling to select that configuration.
std::vector<value> solutions;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
...@@ -36,11 +36,12 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -36,11 +36,12 @@ struct mlir_compiler : compiler<mlir_compiler>
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; } operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace
compile(context& ctx, instruction_ref ins, const operation&, const value& solution) const
{ {
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1); assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
return insert(compile_mlir(ctx, *smod, ins->inputs())); return insert(compile_mlir(ctx, *smod, ins->inputs(), solution));
} }
compiler_replace insert(code_object_op co) const compiler_replace insert(code_object_op co) const
...@@ -50,6 +51,16 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -50,6 +51,16 @@ struct mlir_compiler : compiler<mlir_compiler>
m.replace_instruction(ins, mlir); m.replace_instruction(ins, mlir);
}}; }};
} }
// Report the MLIR tuning candidates for this instruction's fused sub-module.
// Only produced when exhaustive tuning was requested; otherwise no config is
// returned and the default (quick-tuned) path is used.
optional<tuning_config>
get_tuning_config(context&, instruction_ref ins, const operation&, bool exhaustive) const
{
    if(exhaustive)
    {
        auto* sub_module = ins->module_inputs().front();
        return get_tuning_config_mlir(*sub_module, to_shapes(ins->inputs()));
    }
    return nullopt;
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(normalize_permutation(inputs));
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
......
...@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes) ...@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
std::fill(lens.begin(), lens.end(), 1); std::fill(lens.begin(), lens.end(), 1);
for(const auto& axis : axes) for(const auto& axis : axes)
lens[axis] = s.lens()[axis]; lens[axis] = s.lens()[axis];
return shape{s.type(), lens}; return s.with_lens(lens);
} }
template <class T> template <class T>
...@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes) ...@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
auto lens = s.lens(); auto lens = s.lens();
for(const auto& axis : axes) for(const auto& axis : axes)
lens[axis] = 1; lens[axis] = 1;
return shape{s.type(), lens}; return s.with_lens(lens);
} }
template <class ReduceLens> template <class ReduceLens>
...@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto virtual_inputs = inputs; auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes)); virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes)); virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs); virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs));
auto reduce_output_shape = virtual_inputs.back(); auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back(); virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back(); auto reduction_shape = virtual_inputs.back();
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/perfdb.hpp> #include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <deque> #include <deque>
...@@ -134,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD ...@@ -134,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy); using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable, using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
mlirRockTuningTableDestroy); mlirRockTuningTableDestroy);
using mlir_tuning_space = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningSpace,
mlirRockTuningSpaceDestroy);
using mlir_tuning_param = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningParam,
mlirRockTuningParamDestroy);
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; } std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
...@@ -616,18 +621,30 @@ struct mlir_program ...@@ -616,18 +621,30 @@ struct mlir_program
} }
} }
code_object_op compile() MIGRAPHX_TIDY_CONST void run_high_level_pipeline() MIGRAPHX_TIDY_CONST
{ {
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm_front.get()); mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get())); mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
}
// 2nd pipeline to call void run_backend_pipeline() MIGRAPHX_TIDY_CONST
get_module_tuned(); {
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get())); mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
}
code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST
{
// 1st pipeline to call
run_high_level_pipeline();
if(solution.is_null())
get_module_tuned();
else
set_tuning(solution);
// 2nd pipeline to call
run_backend_pipeline();
code_object_op op{}; code_object_op op{};
op.symbol_name = sym_name; op.symbol_name = sym_name;
...@@ -658,6 +675,33 @@ struct mlir_program ...@@ -658,6 +675,33 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program"); MIGRAPHX_THROW("Failed to compile mlir program");
} }
// Apply a previously-selected tuning solution to the current MLIR module.
// Throws if the rocMLIR API rejects the serialized tuning key.
void set_tuning(const value& v)
{
    const auto key = v.to<std::string>();
    // mlirRockTuningSetFromStr may modify its argument in place, so hand it
    // a mutable NUL-terminated copy instead of the string's own buffer.
    std::vector<char> mutable_key(key.begin(), key.end());
    mutable_key.push_back('\0');
    if(not mlirRockTuningSetFromStr(mmodule.get(), mutable_key.data()))
        MIGRAPHX_THROW("Failed setting tuning key: " + key);
}
// Enumerate the rocMLIR tuning space for this module. Runs the high-level
// pipeline first, then collects every candidate parameter set (in serialized
// string form) plus the module's tuning-table key as the problem description.
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
{
    tuning_config tc;
    // Lower the module with the high-level pipeline before querying the
    // tuning space (mirrors compile(), which tunes after this pipeline).
    run_high_level_pipeline();
    mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get())};
    for(auto i : range(mlirRockTuningGetNumParamsFull(params.get())))
    {
        mlir_tuning_param param{mlirRockTuningParamCreate()};
        if(not mlirRockTuningParamGet(params.get(), i, param.get()))
            MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
        // Each candidate configuration is stored as its serialized string.
        tc.solutions.push_back(std::string{mlirRockTuningGetParamStr(param.get())});
    }
    mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
    tc.problem = std::string{mlirRockTuningGetKey(tuning_table.get(), mmodule.get())};
    return tc;
}
std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); } std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }
// This function appends to tuning cfg file that could be // This function appends to tuning cfg file that could be
...@@ -749,14 +793,14 @@ std::string dump_mlir(const module& m) ...@@ -749,14 +793,14 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op); return mlir_print(&mlirOperationPrint, mod_op);
} }
void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs) void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
{ {
auto names = m.get_parameter_names(); auto names = m.get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
for(auto i : range(names.size())) for(auto i : range(names.size()))
{ {
const auto& name = names[i]; const auto& name = names[i];
const auto& input = inputs[i]->get_shape(); const auto& input = inputs[i];
auto param = m.get_parameter(name); auto param = m.get_parameter(name);
if(input.standard()) if(input.standard())
continue; continue;
...@@ -794,9 +838,12 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs) ...@@ -794,9 +838,12 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
} }
} }
code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs) code_object_op compile_mlir(const context&,
module m,
const std::vector<instruction_ref>& inputs,
const value& solution)
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace) if(trace)
...@@ -808,8 +855,9 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct ...@@ -808,8 +855,9 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
auto co = mp.compile(); auto co = mp.compile(solution);
co.output = m.get_output_shapes().front(); co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front();
return co; return co;
} }
...@@ -829,6 +877,16 @@ instruction_ref insert_mlir(module& m, ...@@ -829,6 +877,16 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
// Public entry point: build an mlir_program for module m and query its
// tuning space. The module is taken by value because adjust_param_shapes
// mutates its parameters to match the given input shapes.
tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs)
{
    adjust_param_shapes(m, inputs);
    mlir_program mp;
    mp.find_target();
    mp.parse(m);
    return mp.get_tuning_config();
}
#else #else
std::string dump_mlir(const module&) { return {}; } std::string dump_mlir(const module&) { return {}; }
...@@ -840,11 +898,11 @@ void use(T&) ...@@ -840,11 +898,11 @@ void use(T&)
// Disabling clang-tidy warning on non-real useage. // Disabling clang-tidy warning on non-real useage.
// NOLINTBEGIN(performance-unnecessary-value-param) // NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&) code_object_op
compile_mlir(const context&, module, const std::vector<instruction_ref>&, const value&)
{ {
return {}; return {};
} }
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref instruction_ref
// cppcheck-suppress funcArgNamesDifferent // cppcheck-suppress funcArgNamesDifferent
...@@ -854,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins ...@@ -854,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end(); return m.end();
} }
tuning_config get_tuning_config_mlir(module, const std::vector<shape>&) { return {}; }
// NOLINTEND(performance-unnecessary-value-param)
#endif #endif
} // namespace gpu } // namespace gpu
......
...@@ -75,7 +75,9 @@ namespace gpu { ...@@ -75,7 +75,9 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
struct id_pass struct id_pass
{ {
...@@ -136,7 +138,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -136,7 +138,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{}, dead_code_elimination{},
#ifndef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif
dead_code_elimination{}, dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_type = args[0]->get_shape().type(); auto x_type = args[0]->get_shape().type();
// unsqueeze tensors of shape (C) to broadcast correctly // unsqueeze tensors of shape (C) to broadcast correctly
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = auto scale_unsqueeze =
...@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto var_unsqueeze = auto var_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]); info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps); auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze); return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
} }
}; };
......
...@@ -36,7 +36,7 @@ endfunction() ...@@ -36,7 +36,7 @@ endfunction()
function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR) function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
target_link_libraries(${NAME} migraphx_c migraphx) target_link_libraries(${NAME} migraphx_c)
target_include_directories(${NAME} PUBLIC ../include) target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
......
...@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op) ...@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op)
EXPECT(bool{result == migraphx::argument(s, expected_result.data())}); EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
} }
extern "C" void migraphx_test_private_disable_exception_catch(bool b); extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool);
TEST_CASE(run_sigmoid_with_incorrect_shape) TEST_CASE(run_sigmoid_with_incorrect_shape)
{ {
......
...@@ -34,7 +34,6 @@ TEST_CASE(load_and_run) ...@@ -34,7 +34,6 @@ TEST_CASE(load_and_run)
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx::compile_options options; migraphx::compile_options options;
options.set_offload_copy(); options.set_offload_copy();
options.set_exhaustive_tune_flag();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == 1);
......
...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir) ...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front())); inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::context ctx; migraphx::gpu::context ctx;
migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs), inputs); migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs, {}), inputs);
return p; return p;
} }
......
...@@ -384,7 +384,7 @@ bool throws(F f, const std::string& msg = "") ...@@ -384,7 +384,7 @@ bool throws(F f, const std::string& msg = "")
} }
template <class T, class U> template <class T, class U>
auto near(T px, U py, double ptol = 1e-6f) auto within_abs(T px, U py, double ptol = 1e-6f)
{ {
return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })( return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })(
px, py, ptol); px, py, ptol);
......
...@@ -82,9 +82,9 @@ TEST_CASE(generate_module) ...@@ -82,9 +82,9 @@ TEST_CASE(generate_module)
auto f = compile_module<float(float, float)>(m); auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(2, 2), 2)); EXPECT(test::within_abs(f(2, 2), 2));
EXPECT(test::near(f(10, 6), 4)); EXPECT(test::within_abs(f(10, 6), 4));
EXPECT(test::near(f(1, 2), std::sqrt(3))); EXPECT(test::within_abs(f(1, 2), std::sqrt(3)));
} }
TEST_CASE(generate_module_with_literals) TEST_CASE(generate_module_with_literals)
...@@ -99,9 +99,9 @@ TEST_CASE(generate_module_with_literals) ...@@ -99,9 +99,9 @@ TEST_CASE(generate_module_with_literals)
auto f = compile_module<float(float, float)>(m); auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(1, 2), 2)); EXPECT(test::within_abs(f(1, 2), 2));
EXPECT(test::near(f(9, 6), 4)); EXPECT(test::within_abs(f(9, 6), 4));
EXPECT(test::near(f(0, 2), std::sqrt(3))); EXPECT(test::within_abs(f(0, 2), std::sqrt(3)));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
3be6eb53c8b359703cb645ed2cb1cdf106924b7c 21a71d52bd2074b770807b209939ec11e2c64fa7
...@@ -6165,6 +6165,101 @@ def shape_test(): ...@@ -6165,6 +6165,101 @@ def shape_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def shape_dyn_test0():
    """Shape op on a dynamically-shaped rank-4 input; full 4-element output."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [4])
    shape_node = onnx.helper.make_node('Shape', inputs=['x'], outputs=['y'])
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test()
def shape_dyn_test1():
    """Shape op with a positive start attribute (slice of the input rank)."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
    shape_node = onnx.helper.make_node('Shape',
                                       inputs=['x'],
                                       outputs=['y'],
                                       start=2)
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test()
def shape_dyn_test2():
    """Shape op with a negative start attribute (counted from the end)."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
    shape_node = onnx.helper.make_node('Shape',
                                       inputs=['x'],
                                       outputs=['y'],
                                       start=-2)
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test()
def shape_dyn_test3():
    """Shape op with both start and end attributes set."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
    shape_node = onnx.helper.make_node('Shape',
                                       inputs=['x'],
                                       outputs=['y'],
                                       start=1,
                                       end=2)
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test()
def shape_end_oob_test():
    """Shape op with an end attribute beyond the input rank (out of bounds)."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
    shape_node = onnx.helper.make_node('Shape',
                                       inputs=['x'],
                                       outputs=['y'],
                                       end=5)
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test()
def shape_start_oob_test():
    """Shape op with a negative start attribute beyond the input rank."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
    shape_node = onnx.helper.make_node('Shape',
                                       inputs=['x'],
                                       outputs=['y'],
                                       start=-6)
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test()
def shape_end_less_start_error():
    """Shape op where end precedes start; expected to be rejected by the parser."""
    dyn_input = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                              [None, 4, None, None])
    shape_out = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
    shape_node = onnx.helper.make_node('Shape',
                                       inputs=['x'],
                                       outputs=['y'],
                                       start=3,
                                       end=1)
    return ([shape_node], [dyn_input], [shape_out])
@onnx_test() @onnx_test()
def shape_gather_test(): def shape_gather_test():
values = np.array([1]) values = np.array([1])
......
...@@ -440,14 +440,13 @@ TEST_CASE(batch_norm_flat_test) ...@@ -440,14 +440,13 @@ TEST_CASE(batch_norm_flat_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {1}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {1}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {1}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {1}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, bias});
auto prog = optimize_onnx("batch_norm_flat_test.onnx"); auto prog = optimize_onnx("batch_norm_flat_test.onnx");
...@@ -465,14 +464,13 @@ TEST_CASE(batch_norm_rank_2_test) ...@@ -465,14 +464,13 @@ TEST_CASE(batch_norm_rank_2_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {5}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {5}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {5}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {5}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, bias});
auto prog = optimize_onnx("batch_norm_rank_2_test.onnx"); auto prog = optimize_onnx("batch_norm_rank_2_test.onnx");
...@@ -490,7 +488,6 @@ TEST_CASE(batch_norm_1d_test) ...@@ -490,7 +488,6 @@ TEST_CASE(batch_norm_1d_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-5f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), scale);
...@@ -498,11 +495,11 @@ TEST_CASE(batch_norm_1d_test) ...@@ -498,11 +495,11 @@ TEST_CASE(batch_norm_1d_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_onnx("batch_norm_1d_test.onnx"); auto prog = optimize_onnx("batch_norm_1d_test.onnx");
...@@ -520,7 +517,6 @@ TEST_CASE(batch_norm_2d_test) ...@@ -520,7 +517,6 @@ TEST_CASE(batch_norm_2d_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -528,11 +524,11 @@ TEST_CASE(batch_norm_2d_test) ...@@ -528,11 +524,11 @@ TEST_CASE(batch_norm_2d_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_onnx("batch_norm_2d_test.onnx"); auto prog = optimize_onnx("batch_norm_2d_test.onnx");
...@@ -550,7 +546,6 @@ TEST_CASE(batch_norm_3d_test) ...@@ -550,7 +546,6 @@ TEST_CASE(batch_norm_3d_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::half_type, {2}}); auto mean = mm->add_parameter("mean", {migraphx::shape::half_type, {2}});
auto var = mm->add_parameter("variance", {migraphx::shape::half_type, {2}}); auto var = mm->add_parameter("variance", {migraphx::shape::half_type, {2}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-6f}});
auto usq_scale = auto usq_scale =
...@@ -561,12 +556,13 @@ TEST_CASE(batch_norm_3d_test) ...@@ -561,12 +556,13 @@ TEST_CASE(batch_norm_3d_test)
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), mean); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_onnx("batch_norm_3d_test.onnx"); auto prog = optimize_onnx("batch_norm_3d_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -908,7 +904,6 @@ TEST_CASE(constant_test) ...@@ -908,7 +904,6 @@ TEST_CASE(constant_test)
TEST_CASE(constant_fill_test) TEST_CASE(constant_fill_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
...@@ -1105,7 +1100,6 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -1105,7 +1100,6 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}}); auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}}); auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
uint64_t axis = 1; uint64_t axis = 1;
...@@ -1120,25 +1114,12 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -1120,25 +1114,12 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p5); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p5);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p6); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p6);
auto mb_mean = mm->add_instruction( auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {l5, usq_mean});
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), usq_mean); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto numer = mm->add_instruction(migraphx::make_op("sub"), l5, mb_mean); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto mb_eps = auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1}}}), eps); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
auto var_eps = mm->add_instruction(migraphx::make_op("add"), usq_var, mb_eps); auto l6 = add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto mb_rt =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1}}}), rt);
auto denom = mm->add_instruction(migraphx::make_op("pow"), var_eps, mb_rt);
auto mb_denom = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), denom);
auto div0 = mm->add_instruction(migraphx::make_op("div"), numer, mb_denom);
auto mb_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), usq_scale);
auto r0 = mm->add_instruction(migraphx::make_op("mul"), div0, mb_scale);
auto mb_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), usq_bias);
auto l6 = mm->add_instruction(migraphx::make_op("add"), r0, mb_bias);
auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6); auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
...@@ -6079,6 +6060,118 @@ TEST_CASE(shape_test) ...@@ -6079,6 +6060,118 @@ TEST_CASE(shape_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(shape_dyn_test0)
{
    // ONNX Shape on a dynamic-shape input lowers to "dimensions_of" covering
    // all four axes (start defaults to 0, end = rank).
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}};
    auto p0  = mm->add_parameter("x", s);
    // Removed unused local `s_shape` (declared but never referenced).
    auto ret = mm->add_instruction(migraphx::make_op("dimensions_of", {{"end", 4}}), p0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    // Same dynamic dimensions as parameter "x" above so parsing yields the expected program.
    options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    auto prog = parse_onnx("shape_dyn_test0.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(shape_dyn_test1)
{
    // ONNX Shape with start/end attributes selects a slice of the dimensions:
    // here axes [2, 4) of the dynamic input.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}};
    auto p0 = mm->add_parameter("x", s);
    // Removed unused local `s_shape` (declared but never referenced).
    auto ret =
        mm->add_instruction(migraphx::make_op("dimensions_of", {{"start", 2}, {"end", 4}}), p0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    auto prog = parse_onnx("shape_dyn_test1.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(shape_dyn_test2)
{
    // Parses a different model than shape_dyn_test1 but expects the identical
    // lowering (dimensions_of over axes [2, 4)) — presumably the model encodes
    // the same range differently (e.g. negative indices); confirm against the
    // shape_dyn_test2.onnx generator.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}};
    auto p0 = mm->add_parameter("x", s);
    // Removed unused local `s_shape` (declared but never referenced).
    auto ret =
        mm->add_instruction(migraphx::make_op("dimensions_of", {{"start", 2}, {"end", 4}}), p0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    auto prog = parse_onnx("shape_dyn_test2.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(shape_dyn_test3)
{
    // ONNX Shape selecting a single dimension: axes [1, 2) of the dynamic input.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}};
    auto p0 = mm->add_parameter("x", s);
    // Removed unused local `s_shape` (declared but never referenced).
    auto ret =
        mm->add_instruction(migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 2}}), p0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    auto prog = parse_onnx("shape_dyn_test3.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(shape_end_oob_test)
{
    // An out-of-range `end` in the model is expected to lower to end = 4 (the
    // input's rank) — i.e. the parser appears to clamp it; confirm against the
    // shape_end_oob_test.onnx generator.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}};
    auto p0  = mm->add_parameter("x", s);
    // Removed unused local `s_shape` (declared but never referenced).
    auto ret = mm->add_instruction(migraphx::make_op("dimensions_of", {{"end", 4}}), p0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    auto prog = migraphx::parse_onnx("shape_end_oob_test.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(shape_start_oob_test)
{
    // An out-of-range `start` in the model is expected to lower to the default
    // full range (dimensions_of with end = 4) — i.e. the parser appears to
    // clamp it; confirm against the shape_start_oob_test.onnx generator.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}}};
    auto p0  = mm->add_parameter("x", s);
    // Removed unused local `s_shape` (declared but never referenced).
    auto ret = mm->add_instruction(migraphx::make_op("dimensions_of", {{"end", 4}}), p0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    auto prog = migraphx::parse_onnx("shape_start_oob_test.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(shape_end_less_start_error)
{
    // A Shape node whose `end` precedes `start` must be rejected by the parser.
    migraphx::onnx_options opts;
    opts.map_dyn_input_dims["x"] = {{1, 4, {1, 4}}, {4, 4}, {2, 4}, {2, 4}};
    EXPECT(test::throws(
        [&] { migraphx::parse_onnx("shape_end_less_start_error.onnx", opts); }));
}
TEST_CASE(shape_gather_test) TEST_CASE(shape_gather_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -7150,7 +7243,8 @@ TEST_CASE(variable_batch_user_input_test6) ...@@ -7150,7 +7243,8 @@ TEST_CASE(variable_batch_user_input_test6)
TEST_CASE(variable_batch_user_input_test7) TEST_CASE(variable_batch_user_input_test7)
{ {
// if entry in map_dyn_input_dims is all fixed dynamic_dimensions, convert it to a static shape // if entry in map_dyn_input_dims is all fixed dynamic_dimensions, convert it to a static
// shape
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
......
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment