Commit f12064ee authored by umangyadav

Merge branch 'develop' into resnet50_partition

parents 2c4f70be 6f1c947f
@@ -36,7 +36,10 @@
 #include <mutex>
 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
 #warning "Incompatible version of rocMLIR library used, disabling"
+// Only undefine when not using cppcheck
+#ifndef CPPCHECK
 #undef MIGRAPHX_MLIR
+#endif
 #else
 #include <mlir-c/RegisterRocMLIR.h>
 #endif
@@ -50,8 +53,10 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/gpu/code_object_op.hpp>
 #include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/compile_gen.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/perfdb.hpp>
+#include <migraphx/gpu/tuning_config.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/permutation.hpp>
 #include <deque>
@@ -134,6 +139,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
 using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
 using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
                                                       mlirRockTuningTableDestroy);
+using mlir_tuning_space = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningSpace,
+                                                      mlirRockTuningSpaceDestroy);
+using mlir_tuning_param = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningParam,
+                                                      mlirRockTuningParamDestroy);

 std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
@@ -167,12 +176,6 @@ std::string mlir_print(F f, T x)
     return ss.str();
 }

-bool has_xdlops(const std::string& target_arch)
-{
-    const auto device_name = trim(split_string(target_arch, ':').front());
-    return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
-}

 struct mlir_program
 {
     mlir_program()
@@ -507,7 +510,8 @@ struct mlir_program
         ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
                             {"sym_name", sym_name},
                             {"kernel", std::string("mixr")},
-                            {"arch", target_arch}});
+                            {"arch", target_arch},
+                            {"num_cu", num_cu}});
         ops.add_region(std::move(region));
         insert(body, std::move(ops));
@@ -554,14 +558,7 @@ struct mlir_program
     static std::string get_symbol_name(const module& m)
     {
-        for(auto ins : iterator_for(m))
-        {
-            if(ins->name() == "convolution" or ins->name() == "dot")
-            {
-                return "mlir_" + ins->name();
-            }
-        }
-        return "main";
+        return "mlir_" + gen::generate_name_from_ops(m);
     }

     void parse(const module& m)
@@ -597,9 +594,6 @@ struct mlir_program
         {
             pp =
                 problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
-            // check if HW supports xdlops
-            if(has_xdlops(target_arch))
-                ops.add_attributes({{"xdlopsV2", true}});
         }

         std::vector<MlirValue> inputs;
@@ -616,18 +610,30 @@
         }
     }

-    code_object_op compile() MIGRAPHX_TIDY_CONST
+    void run_high_level_pipeline() MIGRAPHX_TIDY_CONST
     {
         mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
-        mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
-        // 1st pipeline to call
         mlirMIGraphXAddHighLevelPipeline(pm_front.get());
         mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
+    }

-        // 2nd pipeline to call
-        get_module_tuned();
+    void run_backend_pipeline() MIGRAPHX_TIDY_CONST
+    {
+        mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
         mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
         mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
+    }
+
+    code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST
+    {
+        // 1st pipeline to call
+        run_high_level_pipeline();
+        if(solution.is_null())
+            get_module_tuned();
+        else
+            set_tuning(solution);
+        // 2nd pipeline to call
+        run_backend_pipeline();

         code_object_op op{};
         op.symbol_name = sym_name;
@@ -636,7 +642,12 @@ struct mlir_program
         return op;
     }

-    void find_target() { target_arch = get_device_name(); }
+    void set_gpu_properties(const context& migraphx_ctx)
+    {
+        const auto& device = migraphx_ctx.get_current_device();
+        target_arch        = device.get_device_name();
+        num_cu             = device.get_cu_count();
+    }

     std::pair<std::size_t, std::size_t> get_launch_params() const
     {
@@ -650,7 +661,7 @@

     value::binary get_binary() const
     {
-        int size = 0;
+        size_t size = 0;
         mlirGetBinary(mmodule.get(), &size, nullptr);
         value::binary result(size);
         if(mlirGetBinary(mmodule.get(), &size, reinterpret_cast<char*>(result.data())))
@@ -658,14 +669,52 @@
             MIGRAPHX_THROW("Failed to compile mlir program");
     }
+    void set_tuning(const value& v) MIGRAPHX_TIDY_CONST
+    {
+        const auto* str = v.if_string();
+        if(str == nullptr)
+            MIGRAPHX_THROW("mlir tuning solutions must be strings");
+        if(not mlirRockTuningSetFromStr(mmodule.get(), make_mlir_string_ref(*str)))
+            MIGRAPHX_THROW("Failed setting tuning key: " + *str);
+    }
+
+    tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
+    {
+        tuning_config tc;
+        run_high_level_pipeline();
+        mlir_tuning_space params{
+            mlirRockTuningSpaceCreate(mmodule.get(), RocmlirTuningParamSetKindFull)};
+        for(auto i : range(mlirRockTuningGetNumParams(params.get())))
+        {
+            mlir_tuning_param param{mlirRockTuningParamCreate()};
+            if(not mlirRockTuningParamGet(params.get(), i, param.get()))
+                MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
+            std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> perf_key;
+            size_t perf_key_bytes =
+                mlirRockTuningParamToString(param.get(), perf_key.data(), perf_key.size());
+            if(perf_key_bytes > perf_key.size())
+                MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) +
+                               " bytes and thus too long");
+            tc.solutions.emplace_back(perf_key.begin(), perf_key.begin() + perf_key_bytes);
+        }
+        std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key;
+        size_t tuning_key_bytes =
+            mlirRockTuningGetKey(mmodule.get(), tuning_key.data(), tuning_key.size());
+        if(tuning_key_bytes > tuning_key.size())
+            MIGRAPHX_THROW("Tuning table key was " + std::to_string(tuning_key_bytes) +
+                           " bytes and thus too long");
+        tc.problem = std::string(tuning_key.begin(), tuning_key.begin() + tuning_key_bytes);
+        return tc;
+    }
+
     std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }

     // This function appends to tuning cfg file that could be
     // used with rocMLIR tuning scripts.
-    void dump_tuning_cfg(const char* prob_config) const
+    void dump_tuning_cfg(const std::string& prob_config) const
     {
         std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{});
-        if(!tuning_cfg_path.empty())
+        if(not tuning_cfg_path.empty())
         {
             std::vector<std::string> tokens = split_string(prob_config, '\t');
             std::string prob                = tokens[1];
@@ -682,51 +731,66 @@ struct mlir_program
         }
     }

-    static mlir_tuning_table create_tuning_table()
+    static std::pair<mlir_tuning_table, bool> load_tuning_table()
     {
         mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
+        bool found_table           = false;
         std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{});
-        if(!tuning_db_path.empty())
+        if(not tuning_db_path.empty())
         {
             std::ifstream tuning_db_tsv(tuning_db_path);
             if(tuning_db_tsv)
             {
+                found_table = true;
                 std::string line;
                 while(std::getline(tuning_db_tsv, line))
                 {
                     std::vector<std::string> tokens = split_string(line, '\t');
                     std::string arch                = tokens[0];
-                    std::string prob                = tokens[1];
-                    std::string perf                = tokens[2];
-                    std::string key                 = arch.append("\t").append(prob);
-                    mlirRockTuningUpdateTable(tuning_table.get(), key.c_str(), perf.c_str(), 1.0);
+                    std::string num_cu              = tokens[1];
+                    std::string prob                = tokens[2];
+                    std::string perf                = tokens[3];
+                    std::string key = arch.append("\t").append(num_cu).append("\t").append(prob);
+                    mlirRockTuningUpdateTable(tuning_table.get(),
+                                              make_mlir_string_ref(key),
+                                              make_mlir_string_ref(perf),
+                                              1.0);
                 }
             }
         }
         else
         {
+            found_table = false;
             std::cerr
                 << "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
                    "optimal performance."
                << std::endl;
         }
-        return tuning_table;
+        return std::make_pair(std::move(tuning_table), found_table);
    }

     bool get_module_tuned() const
     {
-        static mlir_tuning_table tuning_table = create_tuning_table();
-        // The tuning table as currently implemented is currently not
-        // thread safe. This will be fixed in the future. For now,
-        // stick a mutex around all tuning table interaction.
-        static std::mutex lock;
-        std::lock_guard<std::mutex> guard(lock);
-        if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
+        static std::pair<mlir_tuning_table, bool> tuning_table = load_tuning_table();
+        if(not mlirRockTuningSetFromTable(tuning_table.first.get(), mmodule.get()))
         {
-            const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get());
-            std::stringstream key(prob_config);
-            std::cerr << "fails to set param on" << prob_config << std::endl;
-            dump_tuning_cfg(prob_config);
+            std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> prob_config;
+            size_t prob_config_bytes =
+                mlirRockTuningGetKey(mmodule.get(), prob_config.data(), prob_config.size());
+            if(prob_config_bytes >= prob_config.size())
+            {
+                std::cerr << "MLIR tuning key overflowed buffer, needed " << prob_config_bytes
+                          << " bytes" << std::endl;
+                return false;
+            }
+            std::string prob_config_str(prob_config.begin(),
+                                        prob_config.begin() + prob_config_bytes);
+            if(tuning_table.second)
+            {
+                std::cerr << "NOTE: MLIR tuning table did not include a key for " << prob_config_str
+                          << std::endl;
+            }
+            dump_tuning_cfg(prob_config_str);
             return false;
         }
         return true;
@@ -737,7 +801,8 @@ struct mlir_program
     mlir_module mmodule;
     problem_params pp;
     std::deque<std::string> strings{};
-    std::string target_arch;
+    std::string target_arch = "";
+    std::size_t num_cu      = 0;
     std::string sym_name;
 };
@@ -749,14 +814,14 @@ std::string dump_mlir(const module& m)
     return mlir_print(&mlirOperationPrint, mod_op);
 }

-void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
+void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
 {
     auto names = m.get_parameter_names();
     std::sort(names.begin(), names.end());
     for(auto i : range(names.size()))
     {
         const auto& name  = names[i];
-        const auto& input = inputs[i]->get_shape();
+        const auto& input = inputs[i];
         auto param        = m.get_parameter(name);
         if(input.standard())
             continue;
@@ -794,22 +859,26 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
     }
 }

-code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
+code_object_op compile_mlir(const context& migraphx_ctx,
+                            module m,
+                            const std::vector<instruction_ref>& inputs,
+                            const value& solution)
 {
-    adjust_param_shapes(m, inputs);
+    adjust_param_shapes(m, to_shapes(inputs));
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});

     if(trace)
         std::cout << m << std::endl;

     mlir_program mp;
-    mp.find_target();
+    mp.set_gpu_properties(migraphx_ctx);
     mp.parse(m);
     auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
     if(trace)
         std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
-    auto co   = mp.compile();
-    co.output = m.get_output_shapes().front();
+    auto co            = mp.compile(solution);
+    co.expected_inputs = to_shapes(inputs);
+    co.output          = m.get_output_shapes().front();
     return co;
 }
@@ -829,6 +898,17 @@ instruction_ref insert_mlir(module& m,
     return m.insert_instruction(ins, co, refs);
 }

+tuning_config
+get_tuning_config_mlir(const context& migraphx_ctx, module m, const std::vector<shape>& inputs)
+{
+    adjust_param_shapes(m, inputs);
+    mlir_program mp;
+    mp.set_gpu_properties(migraphx_ctx);
+    mp.parse(m);
+    return mp.get_tuning_config();
+}
+
 #else

 std::string dump_mlir(const module&) { return {}; }
@@ -840,20 +920,27 @@ void use(T&)

 // Disabling clang-tidy warning on non-real useage.
 // NOLINTBEGIN(performance-unnecessary-value-param)
-code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&)
+code_object_op
+compile_mlir(const context&, module, const std::vector<instruction_ref>&, const value&)
 {
     return {};
 }
-// NOLINTEND(performance-unnecessary-value-param)

 instruction_ref
 // cppcheck-suppress funcArgNamesDifferent
 insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
 {
     use(co);
+    use(m);
     return m.end();
 }

+tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&)
+{
+    return {};
+}
+
+// NOLINTEND(performance-unnecessary-value-param)
+
 #endif

 } // namespace gpu
...
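Taken together, the changes above split compilation into an enumerate/apply pair: `get_tuning_config_mlir` exposes the rocMLIR tuning space, and `compile_mlir` now accepts a chosen `solution`. A minimal sketch of how a tuner might drive them, assuming the `tuning_config` members (`problem`, `solutions`) and signatures introduced in this diff; `time_solution` is a hypothetical benchmarking hook, not part of this commit:

```cpp
#include <migraphx/gpu/tuning_config.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/value.hpp>
#include <limits>
#include <vector>

// Hypothetical benchmark hook: a real tuner would compile the module with
// this solution (compile_mlir with a non-null `solution` value) and time it.
template <class Solution>
double time_solution(const Solution&) { return 0.0; }

migraphx::value pick_best(const migraphx::gpu::context& ctx,
                          const migraphx::module& m,
                          const std::vector<migraphx::shape>& inputs)
{
    // Enumerate the full rocMLIR tuning space for this problem
    auto tc = migraphx::gpu::get_tuning_config_mlir(ctx, m, inputs);
    migraphx::value best{};
    double best_time = std::numeric_limits<double>::max();
    for(const auto& perf_key : tc.solutions)
    {
        double t = time_solution(perf_key);
        if(t < best_time)
        {
            best_time = t;
            best      = perf_key;
        }
    }
    // `best` would be recorded against tc.problem (the "arch\tnum_cu\tprob"
    // key) in the MIGRAPHX_MLIR_TUNING_DB tsv that load_tuning_table() reads.
    return best;
}
```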
@@ -34,7 +34,7 @@ namespace gpu {
 std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
 {
     std::vector<argument> args;
-    std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](auto& s) {
+    std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](const auto& s) {
         return to_gpu(generate_argument(s, seed++));
     });
     return args;
...
@@ -338,7 +338,7 @@ void tf_parser::parse_node(const std::string& name)
                 std::string input_name = input;
                 // if input has trailing `:0` index then remove it
                 auto multi_out_idx = input.find(':');
-                if(multi_out_idx != std::string::npos && input.substr(multi_out_idx + 1) == "0")
+                if(multi_out_idx != std::string::npos and input.substr(multi_out_idx + 1) == "0")
                 {
                     input_name = input.substr(0, multi_out_idx);
                 }
...
@@ -285,7 +285,7 @@ bool value::contains(const std::string& pkey) const
 }

 std::size_t value::size() const
 {
-    auto* a = if_array_impl(x);
+    const auto* a = if_array_impl(x);
     if(a == nullptr)
         return 0;
     return a->size();
...
@@ -98,17 +98,11 @@ endfunction()

 function(add_test_executable TEST_NAME)
     add_executable(${TEST_NAME} EXCLUDE_FROM_ALL ${ARGN})
-    target_link_libraries(${TEST_NAME} ${CMAKE_THREAD_LIBS_INIT})
-    # Cmake does not add flags correctly for gcc
-    if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
-        set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread)
-    endif()

     set(TEST_COMMAND ${TEST_NAME})
     add_test_command(${TEST_NAME} ${TEST_COMMAND})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_ref)
+    target_link_libraries(${TEST_NAME} Threads::Threads migraphx migraphx_onnx migraphx_ref)
     target_include_directories(${TEST_NAME} PUBLIC include)
 endfunction(add_test_executable)

@@ -208,11 +202,16 @@ endif()

 function(test_header NAME HEADER)
-    file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
-        "#include <${HEADER}>\nint main() {}\n"
+    file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp "
+#include <${HEADER}>
+int main() {}\n"
     )
-    file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-static-include-${NAME}.cpp
-        "#include <${HEADER}>\n"
+    file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-static-include-${NAME}.cpp "
+#include <${HEADER}>
+#if defined(min) || defined(max) || defined(near) || defined(far)
+#error \"Do not include windows.h in header files\"
+#endif
+\n"
     )
     add_test_executable(${NAME}
         ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
...
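The second generated probe now also guards against `windows.h` macro leakage. Expanded by hand for a hypothetical `HEADER=migraphx/shape.hpp`, the emitted file would look roughly like this:

```cpp
// header-static-include-<NAME>.cpp, as test_header() would generate it
#include <migraphx/shape.hpp>
// windows.h defines min/max/near/far as object-like macros; if any of
// them is visible here, the header under test pulled in windows.h.
#if defined(min) || defined(max) || defined(near) || defined(far)
#error "Do not include windows.h in header files"
#endif
```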
@@ -145,15 +145,15 @@ TEST_CASE(zero_parameter)

 TEST_CASE(set_scalar_parameter)
 {
-    auto p1 = migraphx::parse_onnx("add_bcast_test.onnx");
-    migraphx::shape s1(migraphx_shape_float_type, {3, 4});
+    auto p1 = migraphx::parse_onnx("implicit_add_bcast_test.onnx");
+    migraphx::shape s1(migraphx_shape_float_type, {3, 4, 1});
     auto param_shapes = p1.get_parameter_shapes();
     auto s1_orig      = param_shapes["1"];
     CHECK(bool{s1 == s1_orig});

     migraphx::onnx_options option;
     option.set_input_parameter_shape("1", {});
-    auto p2 = migraphx::parse_onnx("add_bcast_test.onnx", option);
+    auto p2 = migraphx::parse_onnx("implicit_add_bcast_test.onnx", option);
     migraphx::shape s_scalar(migraphx_shape_float_type);
     auto param_shapes_1 = p2.get_parameter_shapes();
     auto s_scalar_after = param_shapes_1["1"];
...
@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op)
     EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
 }

-extern "C" void migraphx_test_private_disable_exception_catch(bool b);
+extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool);

 TEST_CASE(run_sigmoid_with_incorrect_shape)
 {
...
@@ -196,15 +196,47 @@ TEST_CASE(contiguous_pointwise)
             migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y);
         auto yc  = mm->add_instruction(migraphx::make_op("contiguous"), yb);
         auto add = add_pointwise(p, "main:pointwise0", {x, yc}, single_pointwise("add"));
-        mm->add_instruction(pass_op{}, add);
+        auto cadd = mm->add_instruction(migraphx::make_op("contiguous"), add);
+        mm->add_instruction(pass_op{}, cadd);
     }
     auto count = std::distance(mm->begin(), mm->end());
     run_pass(*mm);
-    EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
+    EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
     EXPECT(std::none_of(
         mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "contiguous"; }));
 }

+TEST_CASE(contiguous_nhwc_pointwise)
+{
+    auto s =
+        migraphx::shape::from_permutation(migraphx::shape::float_type, {2, 3, 8, 8}, {0, 2, 3, 1});
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x   = mm->add_parameter("x", s);
+        auto y   = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}});
+        auto yb  = mm->add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y);
+        auto yc   = mm->add_instruction(migraphx::make_op("contiguous"), yb);
+        auto add  = add_pointwise(p1, "main:pointwise0", {x, yc}, single_pointwise("add"));
+        auto cadd = mm->add_instruction(migraphx::make_op("contiguous"), add);
+        mm->add_instruction(pass_op{}, cadd);
+    }
+    run_pass(*p1.get_main_module());
+
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x   = mm->add_parameter("x", s);
+        auto y   = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}});
+        auto yb  = mm->add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y);
+        auto add  = add_pointwise(p2, "main:pointwise0", {x, yb}, single_pointwise("add"));
+        auto cadd = mm->add_instruction(migraphx::make_op("contiguous"), add);
+        mm->add_instruction(pass_op{}, cadd);
+    }
+    EXPECT(p1 == p2);
+}
+
 TEST_CASE(slice_contiguous)
 {
     migraphx::module m;
...
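The NHWC shape in the new test comes from `from_permutation`; a small standalone sketch of the layout it is expected to produce, assuming the permutation lists dimensions from slowest- to fastest-varying as elsewhere in MIGraphX:

```cpp
#include <migraphx/shape.hpp>
#include <iostream>

int main()
{
    // lens {2,3,8,8} in NCHW index order, laid out in memory as N,H,W,C:
    // the C dimension gets stride 1, so the shape is not standard.
    auto s = migraphx::shape::from_permutation(
        migraphx::shape::float_type, {2, 3, 8, 8}, {0, 2, 3, 1});
    std::cout << s << std::endl;            // strides expected: {192, 1, 24, 3}
    std::cout << s.standard() << std::endl; // expected: 0
}
```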
@@ -27,7 +27,7 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/common.hpp>
 #include <migraphx/make_op.hpp>
 #include <test.hpp>

@@ -58,9 +58,8 @@ create_conv(migraphx::instruction_ref& l_img,
     migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
     std::vector<int32_t> weights(4 * channels * 3 * 3);
     auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
-    migraphx::op::convolution op;
-    op.padding_mode = padding_mode;
-    return m.add_instruction(op, l_img, l_weights);
+    return m.add_instruction(
+        migraphx::make_op("convolution", {{"padding_mode", padding_mode}}), l_img, l_weights);
 }

 TEST_CASE(rewrite_pad)
...
@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
     inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));

     migraphx::gpu::context ctx;
-    migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs), inputs);
+    migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs, {}), inputs);
     return p;
 }

@@ -140,7 +140,7 @@ TEST_CASE(conv)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
     return %0 : tensor<1x2x2x2xf32>
   }

@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_convolution(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
     %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
     %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>

@@ -191,7 +191,7 @@ TEST_CASE(quant_dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
     %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
     return %1 : tensor<1x5x3xi32>

@@ -218,7 +218,7 @@ TEST_CASE(dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
     %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
     return %1 : tensor<1x5x3xf32>

@@ -244,7 +244,7 @@ TEST_CASE(conv_int8_dequantize_quantize)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @main(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
     %1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
     %2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>

@@ -277,7 +277,7 @@ TEST_CASE(dot_convert)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
     %1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
     return %1 : tensor<1x5x3xf16>

@@ -303,7 +303,7 @@ TEST_CASE(dot_where)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
+  func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
     %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
     %1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
     return %1 : tensor<1x5x3xf32>
...
@@ -24,7 +24,7 @@
 #include <iostream>
 #include <vector>
 #include <migraphx/gpu/fuse_mlir.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/quantization.hpp>
 #include <migraphx/generate.hpp>

@@ -90,7 +90,7 @@ TEST_CASE(int8_quantization)
         migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
         auto pa = mm->add_parameter("a", sa);
         auto pb = mm->add_parameter("b", sb);
-        mm->add_instruction(migraphx::op::dot{}, pa, pb);
+        mm->add_instruction(migraphx::make_op("dot"), pa, pb);
         return p;
     };
...
@@ -22,6 +22,7 @@
  * THE SOFTWARE.
  */
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #include <cstdio>

@@ -342,11 +343,19 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
     return os;
 }

+inline std::atomic<int>& failures()
+{
+    // NOLINTNEXTLINE
+    static std::atomic<int> f = 0;
+    return f;
+}
+
 template <class T, class F>
 void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
 {
     if(not bool(x.value()))
     {
+        failures()++;
         std::cout << func << std::endl;
         std::cout << file << ":" << line << ":" << std::endl;
         std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " "

@@ -586,13 +595,21 @@ struct driver
         {
             try
             {
+                failures() = 0;
                 f();
             }
+            // cppcheck-suppress EmptyCatchStatement
             catch(const failure_error&)
             {
-                msg = "Test failure";
             }
         }
+        if(msg.empty() and failures() != 0)
+        {
+            if(failures() == 1)
+                msg = "Test failure";
+            else
+                msg = std::to_string(failures()) + " test failures";
+        }
         if(msg.empty())
         {
             out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name

@@ -683,10 +700,10 @@ inline void run(int argc, const char* argv[])
 #define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
 // NOLINTNEXTLINE
 #define CHECK(...) \
     test::failed( \
-        test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \
-        })
+        TEST_CAPTURE(__VA_ARGS__), #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] {})
 // NOLINTNEXTLINE
 #define EXPECT(...) \
     test::failed(TEST_CAPTURE(__VA_ARGS__), \
...
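With the new atomic counter, `CHECK` records a failure and lets the test keep running, while `EXPECT` still aborts the test via `failure_error`; the driver then reports the accumulated count in its summary. A hypothetical test file illustrating the difference, assuming this harness:

```cpp
#include <test.hpp>

TEST_CASE(accumulated_failures)
{
    CHECK(1 + 1 == 3); // counted via failures()++, execution continues
    CHECK(2 + 2 == 5); // also counted
    // The driver summary becomes "2 test failures" rather than stopping
    // at the first failed CHECK.
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
```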
@@ -26,7 +26,6 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/make_op.hpp>
 #include <test.hpp>
...
@@ -26,8 +26,8 @@
 #include <migraphx/insert_pad.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/op/common.hpp>
 #include <basic_ops.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/make_op.hpp>
 #include <test.hpp>

@@ -58,10 +58,11 @@ create_conv(migraphx::instruction_ref& l_img,
     migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
     std::vector<int32_t> weights(4 * channels * 3 * 3);
     auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
-    migraphx::op::convolution op;
-    op.padding_mode = padding_mode;
-    op.padding      = {0, 0, 1, 1};
-    return m.add_instruction(op, l_img, l_weights);
+    return m.add_instruction(
+        migraphx::make_op("convolution",
+                          {{"padding_mode", padding_mode}, {"padding", {0, 0, 1, 1}}}),
+        l_img,
+        l_weights);
 }

 TEST_CASE(rewrite_pad)
...
@@ -24,7 +24,6 @@
 #include <migraphx/layout_nhwc.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/pass_manager.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
...
@@ -89,17 +89,13 @@ bool is_overlap_load(migraphx::instruction_ref a, migraphx::instruction_ref b)

 bool is_disjoint(const std::vector<migraphx::instruction_ref>& inss)
 {
-    for(auto ins1 : inss)
-    {
-        for(auto ins2 : inss)
-        {
-            if(ins1 == ins2)
-                continue;
-            if(is_overlap_load(ins1, ins2))
-                return false;
-        }
-    }
-    return true;
+    return std::none_of(inss.begin(), inss.end(), [&](auto ins1) {
+        return std::none_of(inss.begin(), inss.end(), [&](auto ins2) {
+            if(ins1 == ins2)
+                return true;
+            return is_overlap_load(ins1, ins2);
+        });
+    });
 }

 TEST_CASE(test1)
...
@@ -83,7 +83,7 @@ TEST_CASE(calc_implict_deps)
     auto* else_mod = p.create_module("If_5_else");
     auto l2        = else_mod->add_literal(migraphx::literal(ys, datay));
     auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
-    auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
+    auto a3 = else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
     else_mod->add_return({a3, l2});

     auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});

@@ -95,6 +95,15 @@
     EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
     EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
     EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), lx));
+    EXPECT(migraphx::contains(implicit_deps.at(ret), ly));
+
+    // test for sorting
+    p.sort();
+    auto ret_inputs = ret->inputs();
+    ret_inputs.insert(ret_inputs.end(), implicit_deps.at(ret).begin(), implicit_deps.at(ret).end());
+    EXPECT(std::all_of(ret_inputs.begin(), ret_inputs.end(), [&](const auto i) {
+        return std::distance(mm->begin(), i) < std::distance(mm->begin(), ret);
+    }));
 }

 TEST_CASE(module_annotate)
...
-d3295f4329d744fe1f8419e1220e123807282b99
+a476dbf430ac8315550474a78d47bf182f202d7c
@@ -6414,6 +6414,30 @@ def slice_test():
     return ([node], [x], [y])


+@onnx_test()
+def slice_constant_test():
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2])
+    x_tensor = helper.make_tensor(name='x_tensor',
+                                  data_type=TensorProto.FLOAT,
+                                  dims=[3, 2],
+                                  vals=[0, 1, 2, 3, 4, 5])
+    x = onnx.helper.make_node('Constant',
+                              inputs=[],
+                              outputs=['x'],
+                              value=x_tensor)
+    node = onnx.helper.make_node('Slice',
+                                 inputs=['x'],
+                                 axes=[0, 1],
+                                 starts=[1, 0],
+                                 ends=[2, 2],
+                                 outputs=['1'])
+
+    return ([x, node], [], [y])
+
+
 @onnx_test()
 def slice_dyn_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, None, 2])
@@ -6746,6 +6770,92 @@ def slice_max_end_test():
     return ([node], [x], [y])


+@onnx_test()
+def slice_var_input_static0():
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
+    starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2])
+    ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
+
+    node = onnx.helper.make_node('Slice',
+                                 inputs=['data', 'starts', 'ends'],
+                                 axes=[0, 1],
+                                 outputs=['output'])
+
+    return ([node], [data, starts, ends], [output])
+
+
+@onnx_test()
+def slice_var_input_static1():
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
+    starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2])
+    ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2])
+    axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
+
+    node = onnx.helper.make_node('Slice',
+                                 inputs=['data', 'starts', 'ends', 'axes'],
+                                 outputs=['output'])
+
+    return ([node], [data, starts, ends, axes], [output])
+
+
+@onnx_test()
+def slice_var_input_dyn0():
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
+    starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2])
+    ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
+
+    node = onnx.helper.make_node('Slice',
+                                 inputs=['data', 'starts', 'ends'],
+                                 axes=[0, 1],
+                                 outputs=['output'])
+
+    return ([node], [data, starts, ends], [output])
+
+
+@onnx_test()
+def slice_var_input_dyn1():
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
+    starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2])
+    ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2])
+    axes = helper.make_tensor_value_info('axes', TensorProto.INT32, [2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
+
+    node = onnx.helper.make_node('Slice',
+                                 inputs=['data', 'starts', 'ends', 'axes'],
+                                 outputs=['output'])
+
+    return ([node], [data, starts, ends, axes], [output])
+
+
+@onnx_test()
+def slice_var_input_steps_error():
+    step = np.array([2, 1])
+    step_tensor = helper.make_tensor(name="step",
+                                     data_type=TensorProto.INT32,
+                                     dims=step.shape,
+                                     vals=step.astype(int))
+    arg_step = helper.make_node("Constant",
+                                inputs=[],
+                                outputs=['arg_step'],
+                                value=step_tensor)
+
+    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
+    starts = helper.make_tensor_value_info('starts', TensorProto.FLOAT, [2])
+    ends = helper.make_tensor_value_info('ends', TensorProto.FLOAT, [2])
+    axes = helper.make_tensor_value_info('axes', TensorProto.FLOAT, [2])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
+
+    node = onnx.helper.make_node(
+        'Slice',
+        inputs=['data', 'starts', 'ends', 'axes', 'arg_step'],
+        outputs=['output'])
+
+    return ([arg_step, node], [data, starts, ends, axes], [output])
+
+
 @onnx_test()
 def softmax_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3])
...
@@ -24,7 +24,7 @@
 #include <iostream>
 #include <vector>
 #include <migraphx/literal.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/common.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/pass_manager.hpp>
...