Unverified Commit c3e02b18 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add support in mlir for transposed and broadcasted shaped (#1378)



* Enable non-standard shape
* Use perfdb for non xdlops
* Fix transpose+broadcast strides
Co-authored-by: default avatarjungpark-mlir <jungwook.park@amd.com>
parent 83784c52
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5 -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -145,7 +145,7 @@ void verify_reduced(program p, ...@@ -145,7 +145,7 @@ void verify_reduced(program p,
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl; std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
} }
...@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p, ...@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(p, i, t, options, quantize, inputs, tolerance); verify_reduced(p, i, t, options, quantize, inputs, tolerance);
......
...@@ -49,7 +49,7 @@ struct mlir_conv ...@@ -49,7 +49,7 @@ struct mlir_conv
std::string name() const { return "gpu::mlir_conv"; } std::string name() const { return "gpu::mlir_conv"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
check_shapes{inputs, *this}.standard(); check_shapes{inputs, *this}.packed_or_broadcasted();
if(mods.size() != 1) if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2) if(inputs.size() < 2)
...@@ -70,6 +70,9 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) ...@@ -70,6 +70,9 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
auto group = v.at("group").to<int>(); auto group = v.at("group").to<int>();
if(group != 1) if(group != 1)
return false; return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
return true; return true;
} }
...@@ -96,9 +99,10 @@ struct find_conv_pointwise ...@@ -96,9 +99,10 @@ struct find_conv_pointwise
i.name()); i.name());
})) }))
return; return;
// Only fuse with fp32 for now // Only fuse with fp32/fp16
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->get_shape().type() != shape::type_t::float_type; return not contains({shape::type_t::float_type, shape::type_t::half_type},
i->get_shape().type());
})) }))
return; return;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
......
...@@ -36,7 +36,8 @@ struct module; ...@@ -36,7 +36,8 @@ struct module;
namespace gpu { namespace gpu {
std::string dump_mlir(const module& m); std::string dump_mlir(const module& m);
code_object_op compile_mlir(const context& ctx, const module& m); code_object_op
compile_mlir(const context& ctx, module m, const std::vector<instruction_ref>& inputs);
instruction_ref insert_mlir(module& m, instruction_ref insert_mlir(module& m,
instruction_ref ins, instruction_ref ins,
......
...@@ -41,7 +41,7 @@ struct problem_params ...@@ -41,7 +41,7 @@ struct problem_params
shape output; shape output;
}; };
std::string get_mlir_perf_for_conv(const problem_params& pp); std::string get_mlir_perf_for_conv(const problem_params& pp, bool xdlops);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -41,7 +41,7 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -41,7 +41,7 @@ struct mlir_compiler : compiler<mlir_compiler>
{ {
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1); assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
return insert(compile_mlir(ctx, *smod)); return insert(compile_mlir(ctx, *smod, ins->inputs()));
} }
compiler_replace insert(code_object_op co) const compiler_replace insert(code_object_op co) const
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/make_op.hpp"
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
...@@ -43,8 +44,9 @@ ...@@ -43,8 +44,9 @@
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/perfdb.hpp> #include <migraphx/gpu/perfdb.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <deque> #include <deque>
#include <variant> #include <variant>
...@@ -194,7 +196,6 @@ struct mlir_program ...@@ -194,7 +196,6 @@ struct mlir_program
MlirType make_tensor(const shape& s) const MlirType make_tensor(const shape& s) const
{ {
assert(s.standard());
std::vector<int64_t> lens(s.lens().begin(), s.lens().end()); std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet( return mlirRankedTensorTypeGet(
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull()); lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
...@@ -502,11 +503,12 @@ struct mlir_program ...@@ -502,11 +503,12 @@ struct mlir_program
{ {
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
std::string tuned = get_tune_params(); // check if HW supports xdlops
bool xdlops = contains(get_xdlops_archs(), target_name);
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty()) if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}}); ops.add_attributes({{"perf_config", tuned}});
// check if HW supports xdlops if(xdlops)
if(contains(get_xdlops_archs(), target_name))
ops.add_attributes({{"xdlopsV2", true}}); ops.add_attributes({{"xdlopsV2", true}});
} }
...@@ -571,7 +573,7 @@ struct mlir_program ...@@ -571,7 +573,7 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program"); MIGRAPHX_THROW("Failed to compile mlir program");
} }
std::string get_tune_params() { return get_mlir_perf_for_conv(pp); } std::string get_tune_params(bool xdlops) { return get_mlir_perf_for_conv(pp, xdlops); }
mlir_context ctx; mlir_context ctx;
MlirLocation location; MlirLocation location;
...@@ -589,8 +591,54 @@ std::string dump_mlir(const module& m) ...@@ -589,8 +591,54 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op); return mlir_print(&mlirOperationPrint, mod_op);
} }
code_object_op compile_mlir(const context&, const module& m) void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
{ {
auto names = m.get_parameter_names();
std::sort(names.begin(), names.end());
for(auto i : range(names.size()))
{
const auto& name = names[i];
const auto& input = inputs[i]->get_shape();
auto param = m.get_parameter(name);
if(input.standard())
continue;
auto lens = input.lens();
auto strides = input.strides();
std::vector<operation> ops;
if(input.transposed())
{
auto perm = find_permutation(input);
auto iperm = invert_permutation(perm);
lens = reorder_dims(lens, iperm);
strides = reorder_dims(strides, iperm);
ops.push_back(make_op("transpose", {{"permutation", perm}}));
}
if(input.broadcasted())
{
std::transform(lens.begin(),
lens.end(),
strides.begin(),
lens.begin(),
[](auto len, auto stride) -> std::size_t {
if(stride == 0)
return 1;
return len;
});
ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
}
auto new_param =
std::accumulate(ops.begin(),
ops.end(),
m.add_parameter(name + ".0", shape{input.type(), lens}),
[&](auto x, auto op) { return m.insert_instruction(param, op, x); });
m.replace_instruction(param, new_param);
m.remove_instruction(param);
}
}
code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
{
adjust_param_shapes(m, inputs);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace) if(trace)
std::cout << m << std::endl; std::cout << m << std::endl;
...@@ -662,13 +710,19 @@ instruction_ref insert_mlir(module& m, ...@@ -662,13 +710,19 @@ instruction_ref insert_mlir(module& m,
std::string dump_mlir(const module&) { return {}; } std::string dump_mlir(const module&) { return {}; }
code_object_op compile_mlir(const context&, const module&) { return {}; }
template <class T> template <class T>
void use(T&) void use(T&)
{ {
} }
// Disabling clang-tidy warning on non-real useage.
// NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&)
{
return {};
}
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref instruction_ref
// cppcheck-suppress funcArgNamesDifferent // cppcheck-suppress funcArgNamesDifferent
insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&) insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
......
...@@ -108,16 +108,17 @@ auto query_miopen_db(const std::string& query) ...@@ -108,16 +108,17 @@ auto query_miopen_db(const std::string& query)
} // namespace } // namespace
std::string get_mlir_perf_for_conv(const problem_params& pp) std::string get_mlir_perf_for_conv(const problem_params& pp, bool xdlops)
{ {
std::string solver = xdlops ? "ConvMlirIgemmFwdXdlops" : "ConvMlirIgemmFwd";
std::string query = "select P.* \ std::string query = "select P.* \
from perf_db P, config C \ from perf_db P, config C \
where P.config = C.id AND \ where P.config = C.id AND \
P.solver = 'ConvMlirIgemmFwdXdlops' AND \ P.solver = '${solver}' AND \
${config}"; ${config}";
auto results = auto results = query_miopen_db(
query_miopen_db(interpolate_string(query, {{"config", generate_miopen_config(pp)}})); interpolate_string(query, {{"config", generate_miopen_config(pp)}, {"solver", solver}}));
if(results.empty()) if(results.empty())
return ""; return "";
return results.front().at("params"); return results.front().at("params");
......
...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir) ...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front())); inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::context ctx; migraphx::gpu::context ctx;
migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir), inputs); migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs), inputs);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment