Commit d9568511 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 7dc6e3ae 95431eb7
...@@ -73,20 +73,11 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -73,20 +73,11 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
target get_target(const std::string& name) { return make_target(name); } target get_target(const std::string& name) { return make_target(name); }
migraphx::compile_options to_compile_options(const migraphx_compile_options& options) void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
{
migraphx::compile_options result{};
result.offload_copy = options.offload_copy;
result.fast_math = options.fast_math;
return result;
}
migraphx::file_options to_file_options(const migraphx_file_options& options) void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
{
migraphx::file_options result{}; void set_file_format(file_options& options, const char* format) { options.format = format; }
result.format = options.format;
return result;
}
void set_default_dim_value(onnx_options& options, size_t value) void set_default_dim_value(onnx_options& options, size_t value)
{ {
...@@ -325,6 +316,26 @@ struct migraphx_onnx_options ...@@ -325,6 +316,26 @@ struct migraphx_onnx_options
migraphx::onnx_options object; migraphx::onnx_options object;
}; };
extern "C" struct migraphx_file_options;
struct migraphx_file_options
{
template <class... Ts>
migraphx_file_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::file_options object;
};
extern "C" struct migraphx_compile_options;
struct migraphx_compile_options
{
template <class... Ts>
migraphx_compile_options(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::compile_options object;
};
extern "C" struct migraphx_tf_options; extern "C" struct migraphx_tf_options;
struct migraphx_tf_options struct migraphx_tf_options
{ {
...@@ -683,17 +694,16 @@ extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* o ...@@ -683,17 +694,16 @@ extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* o
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options* options) migraphx_compile_options_t options)
{ {
return migraphx::try_([&] { return migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
if(target == nullptr) if(target == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer");
(program->object) if(options == nullptr)
.compile((target->object), MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(options == nullptr ? migraphx::compile_options{} (program->object).compile((target->object), (options->object));
: migraphx::to_compile_options(*options)));
}); });
} }
...@@ -791,25 +801,24 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati ...@@ -791,25 +801,24 @@ migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operati
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options* options) migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options)
{ {
return migraphx::try_([&] { return migraphx::try_([&] {
*out = allocate<migraphx_program_t>(migraphx::load( if(options == nullptr)
(name), MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(options == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*options)))); *out = allocate<migraphx_program_t>(migraphx::load((name), (options->object)));
}); });
} }
extern "C" migraphx_status extern "C" migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options* options) migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options)
{ {
return migraphx::try_([&] { return migraphx::try_([&] {
if(p == nullptr) if(p == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer");
migraphx::save( if(options == nullptr)
(p->object), MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
(name), migraphx::save((p->object), (name), (options->object));
(options == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*options)));
}); });
} }
...@@ -859,6 +868,65 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o ...@@ -859,6 +868,65 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
}); });
} }
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
return migraphx::try_([&] { destroy((file_options)); });
}
extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options)
{
return migraphx::try_([&] {
*file_options = object_cast<migraphx_file_options_t>(allocate<migraphx::file_options>());
});
}
extern "C" migraphx_status
migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format)
{
return migraphx::try_([&] {
if(file_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer");
migraphx::set_file_format((file_options->object), (format));
});
}
extern "C" migraphx_status
migraphx_compile_options_destroy(migraphx_compile_options_t compile_options)
{
return migraphx::try_([&] { destroy((compile_options)); });
}
extern "C" migraphx_status
migraphx_compile_options_create(migraphx_compile_options_t* compile_options)
{
return migraphx::try_([&] {
*compile_options =
object_cast<migraphx_compile_options_t>(allocate<migraphx::compile_options>());
});
}
extern "C" migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_offload_copy((compile_options->object), (value));
});
}
extern "C" migraphx_status
migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value)
{
return migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_fast_math((compile_options->object), (value));
});
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options) migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{ {
......
...@@ -41,26 +41,6 @@ typedef enum { ...@@ -41,26 +41,6 @@ typedef enum {
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
/// Options to be passed when compiling
typedef struct
{
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
bool offload_copy;
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
bool fast_math;
} migraphx_compile_options;
/// Options for saving and loading files
typedef struct
{
/// Format to be used for file. It can either be json or msgpack
const char* format;
} migraphx_file_options;
typedef struct migraphx_shape* migraphx_shape_t; typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t; typedef const struct migraphx_shape* const_migraphx_shape_t;
...@@ -94,6 +74,12 @@ typedef const struct migraphx_operation* const_migraphx_operation_t; ...@@ -94,6 +74,12 @@ typedef const struct migraphx_operation* const_migraphx_operation_t;
typedef struct migraphx_onnx_options* migraphx_onnx_options_t; typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t; typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
typedef struct migraphx_file_options* migraphx_file_options_t;
typedef const struct migraphx_file_options* const_migraphx_file_options_t;
typedef struct migraphx_compile_options* migraphx_compile_options_t;
typedef const struct migraphx_compile_options* const_migraphx_compile_options_t;
typedef struct migraphx_tf_options* migraphx_tf_options_t; typedef struct migraphx_tf_options* migraphx_tf_options_t;
typedef const struct migraphx_tf_options* const_migraphx_tf_options_t; typedef const struct migraphx_tf_options* const_migraphx_tf_options_t;
...@@ -200,7 +186,7 @@ migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, ...@@ -200,7 +186,7 @@ migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_status migraphx_program_compile(migraphx_program_t program, migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options* options); migraphx_compile_options_t options);
migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out,
migraphx_program_t program); migraphx_program_t program);
...@@ -228,10 +214,10 @@ migraphx_status migraphx_operation_create(migraphx_operation_t* operation, ...@@ -228,10 +214,10 @@ migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
migraphx_status migraphx_status
migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options* options); migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options);
migraphx_status migraphx_status
migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options* options); migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options);
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
...@@ -247,6 +233,23 @@ migraphx_status ...@@ -247,6 +233,23 @@ migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value); int64_t value);
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
const char* format);
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status
migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value);
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options,
bool value);
migraphx_status migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options); migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
......
...@@ -494,6 +494,29 @@ struct module ...@@ -494,6 +494,29 @@ struct module
void print() const { call(&migraphx_module_print, mm); } void print() const { call(&migraphx_module_print, mm); }
}; };
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
compile_options() { this->make_handle(&migraphx_compile_options_create); }
compile_options(migraphx_compile_options* p, own) { this->set_handle(p, own()); }
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
void set_offload_copy(bool value = true)
{
call(&migraphx_compile_options_set_offload_copy, this->get_handle_ptr(), value);
}
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
void set_fast_math(bool value = true)
{
call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value);
}
};
/// A program represents the all computation graphs to be compiled and executed /// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program) struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
...@@ -504,16 +527,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -504,16 +527,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); } program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, migraphx_compile_options poptions) const void compile(const target& ptarget, const compile_options& poptions) const
{ {
call( call(&migraphx_program_compile,
&migraphx_program_compile, this->get_handle_ptr(), ptarget.get_handle_ptr(), &poptions); this->get_handle_ptr(),
ptarget.get_handle_ptr(),
poptions.get_handle_ptr());
} }
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget) const void compile(const target& ptarget) const
{ {
call(&migraphx_program_compile, this->get_handle_ptr(), ptarget.get_handle_ptr(), nullptr); call(&migraphx_program_compile,
this->get_handle_ptr(),
ptarget.get_handle_ptr(),
migraphx::compile_options{}.get_handle_ptr());
} }
/// Return the shapes for the input parameters /// Return the shapes for the input parameters
...@@ -584,28 +612,45 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -584,28 +612,45 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
} }
}; };
// options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{
file_options() { this->make_handle(&migraphx_file_options_create); }
file_options(migraphx_file_options* p, own) { this->set_handle(p, own()); }
// set file format
void set_file_format(const char* format)
{
call(&migraphx_file_options_set_file_format, this->get_handle_ptr(), format);
}
};
/// Load a saved migraphx program from a file /// Load a saved migraphx program from a file
inline program load(const char* filename, migraphx_file_options options) inline program load(const char* filename, const file_options& options)
{ {
return program(make<migraphx_program>(&migraphx_load, filename, &options), own{}); return program(make<migraphx_program>(&migraphx_load, filename, options.get_handle_ptr()),
own{});
} }
/// Load a saved migraphx program from a file /// Load a saved migraphx program from a file
inline program load(const char* filename) inline program load(const char* filename)
{ {
return program(make<migraphx_program>(&migraphx_load, filename, nullptr), own{}); return program(
make<migraphx_program>(&migraphx_load, filename, migraphx::file_options{}.get_handle_ptr()),
own{});
} }
/// Save a program to a file /// Save a program to a file
inline void save(const program& p, const char* filename, migraphx_file_options options) inline void save(const program& p, const char* filename, const file_options& options)
{ {
call(&migraphx_save, p.get_handle_ptr(), filename, &options); call(&migraphx_save, p.get_handle_ptr(), filename, options.get_handle_ptr());
} }
/// Save a program to a file /// Save a program to a file
inline void save(const program& p, const char* filename) inline void save(const program& p, const char* filename)
{ {
call(&migraphx_save, p.get_handle_ptr(), filename, nullptr); call(&migraphx_save, p.get_handle_ptr(), filename, migraphx::file_options{}.get_handle_ptr());
} }
/// Options for parsing onnx options /// Options for parsing onnx options
......
...@@ -250,6 +250,25 @@ def onnx_options(h): ...@@ -250,6 +250,25 @@ def onnx_options(h):
) )
@auto_handle()
def file_options(h):
h.constructor('create')
h.method('set_file_format',
api.params(format='const char*'),
invoke='migraphx::set_file_format($@)')
@auto_handle()
def compile_options(h):
h.constructor('create')
h.method('set_offload_copy',
api.params(value='bool'),
invoke='migraphx::set_offload_copy($@)')
h.method('set_fast_math',
api.params(value='bool'),
invoke='migraphx::set_fast_math($@)')
api.add_function('migraphx_parse_onnx', api.add_function('migraphx_parse_onnx',
api.params(name='const char*', api.params(name='const char*',
options='migraphx::onnx_options'), options='migraphx::onnx_options'),
......
...@@ -46,6 +46,9 @@ struct module ...@@ -46,6 +46,9 @@ struct module
std::string name() const; std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
......
...@@ -28,6 +28,7 @@ struct module_impl ...@@ -28,6 +28,7 @@ struct module_impl
std::unordered_set<instruction*> instruction_set; std::unordered_set<instruction*> instruction_set;
std::string name; std::string name;
uint32_t nparams = 0; uint32_t nparams = 0;
bool bypass = false;
bool contains(instruction_ref ins) const bool contains(instruction_ref ins) const
{ {
...@@ -49,6 +50,13 @@ struct module_impl ...@@ -49,6 +50,13 @@ struct module_impl
return emplace(pos, ins); return emplace(pos, ins);
} }
void clear()
{
instructions.clear();
instruction_set.clear();
nparams = 0;
}
void push_front(const instruction& ins) { insert(instructions.begin(), ins); } void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); } void push_back(const instruction& ins) { insert(instructions.end(), ins); }
...@@ -100,18 +108,21 @@ module& module::operator=(module m) ...@@ -100,18 +108,21 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; } std::string module::name() const { return impl->name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
void module::assign(const module& m) void module::assign(const module& m)
{ {
// clean the current module // copy the impl
if(!impl) if(!impl)
{
impl = std::make_unique<module_impl>(); impl = std::make_unique<module_impl>();
} *impl = *m.impl;
else if(!impl->instructions.empty())
// clear instructions
if(!impl->instructions.empty())
{ {
impl->instructions.clear(); impl->clear();
} }
impl->name = m.impl->name;
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
......
...@@ -95,6 +95,8 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) ...@@ -95,6 +95,8 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
auto mods = prog.get_modules(); auto mods = prog.get_modules();
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
if(mod->bypass())
continue;
module_pm{mod, &prog, &trace}.run_pass(p); module_pm{mod, &prog, &trace}.run_pass(p);
} }
run_pass(prog, p, trace); run_pass(prog, p, trace);
......
...@@ -3,14 +3,14 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -3,14 +3,14 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
rocm_clang_tidy_check(${NAME}) rocm_clang_tidy_check(${NAME})
target_link_libraries(${NAME} migraphx_c) target_link_libraries(${NAME} migraphx_c migraphx)
target_include_directories(${NAME} PUBLIC ../include) target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <migraphx/compile_options.hpp>
#include "test.hpp"
TEST_CASE(compile_options_api_test)
{
migraphx::api::compile_options options;
options.set_offload_copy(false);
options.set_fast_math(false);
const auto* s_options = reinterpret_cast<const migraphx::MIGRAPHX_INLINE_NS::compile_options*>(
options.get_handle_ptr());
CHECK(s_options->fast_math == false);
CHECK(s_options->offload_copy == false);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -7,8 +7,8 @@ TEST_CASE(load_and_run) ...@@ -7,8 +7,8 @@ TEST_CASE(load_and_run)
{ {
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx_compile_options options; migraphx::compile_options options;
options.offload_copy = true; options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == 1);
...@@ -30,8 +30,8 @@ TEST_CASE(if_pl_test) ...@@ -30,8 +30,8 @@ TEST_CASE(if_pl_test)
auto run_prog = [&](auto cond) { auto run_prog = [&](auto cond) {
auto p = migraphx::parse_onnx("if_pl_test.onnx"); auto p = migraphx::parse_onnx("if_pl_test.onnx");
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx_compile_options options; migraphx::compile_options options;
options.offload_copy = true; options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == 1);
...@@ -81,8 +81,8 @@ TEST_CASE(loop_test) ...@@ -81,8 +81,8 @@ TEST_CASE(loop_test)
parse_options.set_default_loop_iterations(max_iter_num); parse_options.set_default_loop_iterations(max_iter_num);
auto p = migraphx::parse_onnx("loop_default_test.onnx", parse_options); auto p = migraphx::parse_onnx("loop_default_test.onnx", parse_options);
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx_compile_options options; migraphx::compile_options options;
options.offload_copy = true; options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 2); CHECK(shapes_before.size() == 2);
......
...@@ -22,8 +22,8 @@ TEST_CASE(load_save_json) ...@@ -22,8 +22,8 @@ TEST_CASE(load_save_json)
std::string filename = "migraphx_api_load_save.json"; std::string filename = "migraphx_api_load_save.json";
auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto s1 = p1.get_output_shapes(); auto s1 = p1.get_output_shapes();
migraphx_file_options options; migraphx::file_options options;
options.format = "json"; options.set_file_format("json");
migraphx::save(p1, filename.c_str(), options); migraphx::save(p1, filename.c_str(), options);
auto p2 = migraphx::load(filename.c_str(), options); auto p2 = migraphx::load(filename.c_str(), options);
......
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <sstream> #include <sstream>
...@@ -276,4 +277,39 @@ TEST_CASE(parameter_name_order) ...@@ -276,4 +277,39 @@ TEST_CASE(parameter_name_order)
EXPECT(param_names == names1); EXPECT(param_names == names1);
} }
struct check_for_pass_op
{
bool* found = nullptr;
std::string name() const { return "check_for_pass_op"; }
void apply(migraphx::module& m) const
{
*found |= std::any_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "pass"; });
}
};
TEST_CASE(module_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->set_bypass();
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(not found);
}
TEST_CASE(module_without_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(found);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -73,20 +73,11 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -73,20 +73,11 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
target get_target(const std::string& name) { return make_target(name); } target get_target(const std::string& name) { return make_target(name); }
migraphx::compile_options to_compile_options(const migraphx_compile_options& options) void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
{
migraphx::compile_options result{};
result.offload_copy = options.offload_copy;
result.fast_math = options.fast_math;
return result;
}
migraphx::file_options to_file_options(const migraphx_file_options& options) void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
{
migraphx::file_options result{}; void set_file_format(file_options& options, const char* format) { options.format = format; }
result.format = options.format;
return result;
}
void set_default_dim_value(onnx_options& options, size_t value) void set_default_dim_value(onnx_options& options, size_t value)
{ {
......
...@@ -41,26 +41,6 @@ typedef enum { ...@@ -41,26 +41,6 @@ typedef enum {
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
/// Options to be passed when compiling
typedef struct
{
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
/// offloaded memory and to copy the final result from the offloaded
/// memory back to main memory.
bool offload_copy;
/// Optimize math functions to use faster approximate versions. There may
/// be slight accuracy degredation when enabled.
bool fast_math;
} migraphx_compile_options;
/// Options for saving and loading files
typedef struct
{
/// Format to be used for file. It can either be json or msgpack
const char* format;
} migraphx_file_options;
<% generate_c_header() %> <% generate_c_header() %>
#ifdef __cplusplus #ifdef __cplusplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment