Unverified Commit 205306ac authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable mlir quick tuning (#2183)

parent 15acaee9
......@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m,
const std::vector<shape>& inputs);
const std::vector<shape>& inputs,
bool exhaustive);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
const operation&,
bool exhaustive) const
{
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(ctx, *smod, shapes);
return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
}
};
......
......@@ -682,11 +682,12 @@ struct mlir_program
MIGRAPHX_THROW("Failed setting tuning key: " + *str);
}
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST
{
tuning_config tc;
run_high_level_pipeline();
auto tuning_mode = RocmlirTuningParamSetKindFull;
auto tuning_mode =
exhaustive ? RocmlirTuningParamSetKindFull : RocmlirTuningParamSetKindQuick;
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
......@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs);
}
tuning_config
get_tuning_config_mlir(const context& migraphx_ctx, module m, const std::vector<shape>& inputs)
tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m,
const std::vector<shape>& inputs,
bool exhaustive)
{
adjust_param_shapes(m, inputs);
mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
return mp.get_tuning_config();
return mp.get_tuning_config(exhaustive);
}
#else
......@@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end();
}
tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&)
tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&, bool)
{
return {};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment