Unverified Commit 205306ac authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable mlir quick tuning (#2183)

parent 15acaee9
...@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, ...@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx, MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<shape>& inputs); const std::vector<shape>& inputs,
bool exhaustive);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
const operation&, const operation&,
bool exhaustive) const bool exhaustive) const
{ {
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(ctx, *smod, shapes); return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
} }
}; };
......
...@@ -682,11 +682,12 @@ struct mlir_program ...@@ -682,11 +682,12 @@ struct mlir_program
MIGRAPHX_THROW("Failed setting tuning key: " + *str); MIGRAPHX_THROW("Failed setting tuning key: " + *str);
} }
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
auto tuning_mode = RocmlirTuningParamSetKindFull; auto tuning_mode =
exhaustive ? RocmlirTuningParamSetKindFull : RocmlirTuningParamSetKindQuick;
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{})) if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive; tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)}; mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
...@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m, ...@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
tuning_config tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
get_tuning_config_mlir(const context& migraphx_ctx, module m, const std::vector<shape>& inputs) module m,
const std::vector<shape>& inputs,
bool exhaustive)
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, inputs);
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
return mp.get_tuning_config(); return mp.get_tuning_config(exhaustive);
} }
#else #else
...@@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins ...@@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end(); return m.end();
} }
tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&) tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&, bool)
{ {
return {}; return {};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment