"include/vscode:/vscode.git/clone" did not exist on "057ffb90846eef043c9fd5d45f4168892e23dfdc"
Unverified Commit cc098f4d authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add flag for tuning in migraphx-driver (#1519)

* Add driver flag "--exhaustive-tune" to enable tuning, add support for the same in C/C++ and python API
parent 102c6bdb
...@@ -35,6 +35,7 @@ TEST_CASE(load_and_run) ...@@ -35,6 +35,7 @@ TEST_CASE(load_and_run)
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx::compile_options options; migraphx::compile_options options;
options.set_offload_copy(); options.set_offload_copy();
options.set_exhaustive_tune_flag();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == 1);
......
...@@ -33,7 +33,8 @@ def test_conv_relu(): ...@@ -33,7 +33,8 @@ def test_conv_relu():
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p) print(p)
print("Compiling ...") print("Compiling ...")
p.compile(migraphx.get_target("gpu")) # set offload_copy, fast_match and exhaustive_tune to true
p.compile(migraphx.get_target("gpu"), True, True, True)
print(p) print(p)
params = {} params = {}
......
...@@ -134,6 +134,11 @@ void set_offload_copy(compile_options& options, bool value) { options.offload_co ...@@ -134,6 +134,11 @@ void set_offload_copy(compile_options& options, bool value) { options.offload_co
void set_fast_math(compile_options& options, bool value) { options.fast_math = value; } void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
void set_exhaustive_tune_flag(compile_options& options, bool value)
{
options.exhaustive_tune = value;
}
void set_file_format(file_options& options, const char* format) { options.format = format; } void set_file_format(file_options& options, const char* format) { options.format = format; }
void set_default_dim_value(onnx_options& options, size_t value) void set_default_dim_value(onnx_options& options, size_t value)
......
...@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&) ...@@ -66,6 +66,7 @@ any_ptr get_queue_context(T&)
{ {
return {}; return {};
} }
template <class T> template <class T>
void wait_for_context(T&, any_ptr) void wait_for_context(T&, any_ptr)
{ {
...@@ -87,6 +88,7 @@ void finish_on_context(T&, any_ptr){} ...@@ -87,6 +88,7 @@ void finish_on_context(T&, any_ptr){}
{ {
v = ctx.to_value(); v = ctx.to_value();
} }
inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); }
#endif #endif
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
import string, sys, re import string, sys, re
trivial = [ trivial = [
'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref' 'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref',
'bool', 'any_ptr'
] ]
headers = ''' headers = '''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment