Unverified Commit 70d9faf7 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into mi200

parents a56c531c a60bdb67
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_8bits.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
...@@ -41,8 +41,6 @@ ...@@ -41,8 +41,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type() static std::vector<shape::type_t>& get_quantizable_type()
{ {
static std::vector<shape::type_t> quantable_types = { static std::vector<shape::type_t> quantable_types = {
...@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type() ...@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return quantable_types; return quantable_types;
} }
void quantize_int8_pass::apply(module& m) const // NOLINT void quantize_8bits_pass::apply(module& m) const // NOLINT
{ {
const auto& quantizable_types = get_quantizable_type(); const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
...@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT ...@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto input = ins->inputs().front(); auto input = ins->inputs().front();
auto s = input->get_shape(); auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type) if(contains(quantizable_types, s.type()) and s.type() != precision)
{ {
auto zero_point = m.add_literal(static_cast<int8_t>(param.second)); auto zero_point =
m.add_literal(migraphx::literal{migraphx::shape{precision}, {param.second}});
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first})); auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens(); const auto& lens = s.lens();
scale = scale =
...@@ -87,9 +86,11 @@ void quantize_int8_pass::apply(module& m) const // NOLINT ...@@ -87,9 +86,11 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void capture_arguments_pass::apply(module& m) const // NOLINT void capture_arguments_pass::apply(module& m) const // NOLINT
{ {
assert(param_index != nullptr); assert(param_index != nullptr);
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(not contains(ins_names, ins->name())) if((not contains(ins_names, ins->name())) or (ins->name() == "convert"))
{ {
continue; continue;
} }
...@@ -98,8 +99,15 @@ void capture_arguments_pass::apply(module& m) const // NOLINT ...@@ -98,8 +99,15 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
std::vector<instruction_ref> new_args; std::vector<instruction_ref> new_args;
for(auto input : inputs) for(auto input : inputs)
{ {
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input); if(contains(quantizable_types, input->get_shape().type()))
new_args.push_back(new_in); {
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
else
{
new_args.push_back(input);
}
} }
m.replace_instruction(ins, ins->get_operator(), new_args); m.replace_instruction(ins, ins->get_operator(), new_args);
} }
......
...@@ -56,7 +56,11 @@ target make_target(const std::string& name) ...@@ -56,7 +56,11 @@ target make_target(const std::string& name)
{ {
if(not contains(target_map(), name)) if(not contains(target_map(), name))
{ {
#ifdef _WIN32
std::string target_name = "migraphx_" + name + ".dll";
#else
std::string target_name = "libmigraphx_" + name + ".so"; std::string target_name = "libmigraphx_" + name + ".so";
#endif
store_target_lib(dynamic_loader(target_name)); store_target_lib(dynamic_loader(target_name));
} }
const auto it = target_map().find(name); const auto it = target_map().find(name);
......
...@@ -35,6 +35,110 @@ ...@@ -35,6 +35,110 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static void replace_with_reduce(module& m, instruction_ref ins)
{
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
auto lens = s.lens();
std::vector<std::int64_t> axes(lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
// average pooling
if(op.mode == op::pooling_mode::average)
{
m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
}
// max pooling
else
{
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
}
}
static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
// Ignore N, C axes
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
if(not default_padding)
{
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// We need to pad both ends
dims[idx] += op.padding.at(idx) * 2;
}
}
std::vector<size_t> kernels = op.lengths;
std::vector<size_t> strides = op.stride;
std::vector<size_t> dilations = op.dilations;
std::vector<std::vector<int>> axis_indices;
axis_indices.resize(dims.size());
for(auto idx{0}; idx < dims.size(); ++idx)
{
// Only consider if iw fits into the window
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
stride += strides.at(idx))
{
for(size_t step{0}; step < kernels.at(idx); ++step)
{
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
}
}
}
auto elements = ins->inputs().front();
if(not default_padding)
{
// Pad supports asym, we need to provide both ends
std::vector<size_t> padding(2 * s.lens().size(), 0);
// Format will be e.g {N, C, P1, P2, N, C, P1, P2}
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// Ignore N, C axes
padding.at(2 + idx) = op.padding.at(idx);
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
}
// Default value needed for Max pooling
elements = m.insert_instruction(
ins,
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
elements);
}
for(auto idx{0}; idx < axis_indices.size(); ++idx)
{
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
elements = m.insert_instruction(
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
}
// Ignore padding
std::vector<size_t> new_padding(kernels.size(), 0);
// The kernel window elements are places next to each other. E.g. {x1, y1, x2, y2, ...}
// We need to skip them to not overlap
std::vector<size_t> new_strides(kernels);
// Ignore dilations
std::vector<size_t> new_dilations(kernels.size(), 1);
m.replace_instruction(ins,
make_op("pooling",
{{"mode", op.mode},
{"padding", new_padding},
{"stride", new_strides},
{"lengths", kernels},
{"dilations", new_dilations}}),
elements);
}
void rewrite_pooling::apply(module& m) const void rewrite_pooling::apply(module& m) const
{ {
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
...@@ -43,26 +147,36 @@ void rewrite_pooling::apply(module& m) const ...@@ -43,26 +147,36 @@ void rewrite_pooling::apply(module& m) const
continue; continue;
if(ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; })) bool same_kernel_as_shape = std::equal(
continue; s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; })) bool default_strides =
continue; std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
auto lens = s.lens(); bool default_padding =
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end())) std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
continue; bool default_dilations =
std::vector<std::int64_t> axes(lens.size() - 2); std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
std::iota(axes.begin(), axes.end(), 2); if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
// average pooling
if(op.mode == op::pooling_mode::average)
{ {
m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs()); replace_with_reduce(m, ins);
} }
// max pooling else if(not default_dilations)
else
{ {
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs()); // Dilated AvgPool with padding is not supported
if(not default_padding and op.mode == op::pooling_mode::average)
{
continue;
}
auto size =
std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
// Can't handle too much size because of literal size
if(size > 100000)
{
continue;
}
replace_dilations_with_gather_pooling(m, ins);
} }
} }
} }
......
...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x); ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
} }
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale); auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div); auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
int64_t max_quant = 0; double max_quant = 0;
int64_t min_quant = 0; double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) { ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{})) if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{ {
std::vector<int> min_data(s.elements(), min_quant); std::vector<double> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant); std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data)); min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data)); max_arg = m.add_literal(literal(s, max_data));
} }
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp> #include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp> #include <migraphx/dom_info.hpp>
...@@ -461,7 +461,7 @@ struct stream_info ...@@ -461,7 +461,7 @@ struct stream_info
std::back_inserter(index_to_ins), std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; }); [](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) { simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index]; auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0); assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first); auto& merge_second = concur_ins.at(merge_first);
......
...@@ -941,15 +941,6 @@ struct find_splits ...@@ -941,15 +941,6 @@ struct find_splits
{ {
auto split = i->inputs()[split_idx]; auto split = i->inputs()[split_idx];
assert(split->name() == "slice"); assert(split->name() == "slice");
// Insert contiguous for reshapes
auto outputs = i->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
m.replace_instruction(i, split->get_operator(), c); m.replace_instruction(i, split->get_operator(), c);
} }
...@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion ...@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
auto outputs = arg->outputs(); auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction( m.replace_instruction(
...@@ -1487,11 +1471,6 @@ struct find_split_reshape ...@@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len; slc_axis_len;
}); });
// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction( auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
...@@ -22,8 +22,10 @@ ...@@ -22,8 +22,10 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/simplify_dyn_ops.hpp> #include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -32,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS {
* Convert 2 input static shape broadcast/multibroadcast into 1 input version. * Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions * Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators. * of the broadcasting operators.
* From:
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/ */
struct find_static_2in_broadcasts struct find_static_2in_broadcasts
{ {
...@@ -60,8 +66,65 @@ struct find_static_2in_broadcasts ...@@ -60,8 +66,65 @@ struct find_static_2in_broadcasts
}; };
/** /**
* Simplify slice with variable `starts` and `ends` to the constant version if * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
* the `input_starts` and `input_ends` inputs are constant. * From:
* slice(data, constant_input); two attributes set
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/
struct find_const_2in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::ends_axes)
{
// slice(data, starts)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::starts_axes)
{
// slice(data, ends)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
axes_vec = slice_op.axes;
}
else
{
// slice(data, axes)
inputs.at(1)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
ends_vec = slice_op.ends;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
};
/**
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
* From:
* slice(data, constant_input1, constant_input2); one attribute set
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/ */
struct find_const_3in_slice struct find_const_3in_slice
{ {
...@@ -76,27 +139,51 @@ struct find_const_3in_slice ...@@ -76,27 +139,51 @@ struct find_const_3in_slice
{ {
auto ins = mr.result; auto ins = mr.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(); auto slice_op = any_cast<op::slice>(ins->get_operator());
argument ends_arg = inputs.at(2)->eval(); auto set_attrs = slice_op.get_set_attributes();
if(not starts_arg.empty() and not ends_arg.empty()) std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::axes_only)
{ {
std::vector<int64_t> starts_vec; // slice(data, starts, ends)
std::vector<int64_t> ends_vec; inputs.at(1)->eval().visit(
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); }); [&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); }); inputs.at(2)->eval().visit(
auto slice_val = ins->get_operator().to_value(); [&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto axes_vec = slice_val.at("axes").to_vector<int64_t>(); axes_vec = slice_op.axes;
m.replace_instruction( }
ins, else if(set_attrs == op::slice::ends_only)
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), {
inputs.at(0)); // slice(data, starts, axes)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
}
else
{
// slice(data, ends, axes)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
} }
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
} }
}; };
/** /**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if * Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant. * From:
* slice(data, constant_starts, constant_ends, constant_axes)
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/ */
struct find_const_4in_slice struct find_const_4in_slice
{ {
...@@ -112,9 +199,9 @@ struct find_const_4in_slice ...@@ -112,9 +199,9 @@ struct find_const_4in_slice
{ {
auto ins = mr.result; auto ins = mr.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(); argument starts_arg = inputs.at(1)->eval(false);
argument ends_arg = inputs.at(2)->eval(); argument ends_arg = inputs.at(2)->eval(false);
argument axes_arg = inputs.at(3)->eval(); argument axes_arg = inputs.at(3)->eval(false);
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty()) if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{ {
std::vector<int64_t> starts_vec; std::vector<int64_t> starts_vec;
...@@ -131,10 +218,116 @@ struct find_const_4in_slice ...@@ -131,10 +218,116 @@ struct find_const_4in_slice
} }
}; };
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
/**
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
* From:
* x = allocate(constant_output_dims) -> reshape(data, x)
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct find_const_alloc_reshapes
{
auto matcher() const
{
return match::name("reshape")(match::nargs(2),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto reshape_ins = mr.result;
auto reshape_inputs = reshape_ins->inputs();
auto alloc_ins = reshape_inputs.at(1);
argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false);
std::vector<int64_t> output_dims_vec;
output_dims_arg.visit(
[&](auto output) { output_dims_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0));
// have dead_code_elimination remove the previous allocate
}
};
/**
* Simplify allocate into fill operator that has constant output dimensions and constant value.
* The allocate into fill instructions is what is produced when parsing the ONNX
* ConstantOfShape operator. This replacement could be handled with propagate_constant, but
* would rather have the simplification happen earlier during compiling.
* This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
* From:
* x = allocate(constant_ouptut_dims) -> fill(constant_value, x)
* To:
* literal
*/
struct find_const_alloc_fill
{
auto matcher() const
{
return match::name("fill")(match::arg(0)(match::is_constant()),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto fill_ins = mr.result;
auto fill_arg = fill_ins->eval(false);
auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data());
m.replace_instruction(fill_ins, l);
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches( match::find_matches(m,
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{}); find_static_dimensions_of{},
find_const_alloc_reshapes{},
find_static_2in_broadcasts{},
find_const_2in_slice{},
find_const_3in_slice{},
find_const_4in_slice{},
find_const_alloc_fill{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -45,77 +45,149 @@ std::unordered_set<std::string> get_quantizable_op_names() ...@@ -45,77 +45,149 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s; return s;
} }
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins) struct match_find_quantizable_ops
{ {
if(ins->name() != "@literal") static bool
return false; is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
bool all_same = false; {
ins->get_literal().visit([&](auto s) { return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) { }
return float_equal(scale, s.front());
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
}); });
}); return all_zeros;
return all_same; }
}
struct match_find_quantizable_ops static auto
{ scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
{
auto prev_ins = qop_arg;
std::vector<instruction_ref> ins_inbetween;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
// instructions
while(prev_ins != dqins)
{
ins_inbetween.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto qinp = dqins->inputs().front();
for(auto ins : reverse_iterator_for(ins_inbetween))
{
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
}
return qinp;
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale) static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{ {
return match::name("dequantizelinear")( return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))), match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))), match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0))))); match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
} }
auto matcher() const auto matcher() const
{ {
return match::name(get_quantizable_op_names())( return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")), match::arg(0)(match::skip_broadcasts_transposes_contiguous(
match::arg(1)(dequantizelinear_op("x2", "scale2"))); dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto dq1 = r.instructions["dq1"];
auto q2 = r.instructions["x2"]; auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"]; auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"]; auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::int8_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;
// Only INT8 type currently supported // Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(q1->get_shape().type() != migraphx::shape::int8_type or if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
q2->get_shape().type() != migraphx::shape::int8_type)
return; return;
double scale; // Only support scalar and 1D scales
visit_all(scale1->get_literal(), scale2->get_literal())( if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); }); return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs(); auto qop_args = qop->inputs();
qop_args.at(0) = q1; qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
qop_args.at(1) = q2; qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq; instruction_ref dq;
instruction_ref dq_scale; instruction_ref out_scale;
instruction_ref zero_point; instruction_ref zero_point;
if(qop->name() == "convolution") if(qop->name() == "convolution")
{ {
auto conv_val = qop->get_operator().to_value(); auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction( dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args); qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
} }
else if(qop->name() == "dot") else if(qop->name() == "dot")
{ {
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args); dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
} }
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens(); dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq); m.replace_instruction(qop, dq);
} }
}; };
...@@ -138,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2) ...@@ -138,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool diff_shapes_equal_vals = false; bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) { visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals = diff_shapes_equal_vals =
std::all_of( std::all_of(l1.begin() + 1,
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and l1.end(),
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); }); [&](auto v) {
return ((float_equal(v, l1.front())) or
(std::isinf(l1.front()) and std::isinf(v)));
}) and
std::all_of(l2.begin(), l2.end(), [&](auto v) {
return ((float_equal(v, l1.front())) or (std::isinf(l1.front()) and std::isinf(v)));
});
}); });
return (x == y) or diff_shapes_equal_vals; return (x == y) or diff_shapes_equal_vals;
......
...@@ -103,8 +103,6 @@ struct find_reshaper ...@@ -103,8 +103,6 @@ struct find_reshaper
auto input = mr.instructions["x"]; auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens(); auto dims = ins->get_shape().lens();
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input); m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
} }
}; };
...@@ -185,6 +183,11 @@ struct find_nested_convert ...@@ -185,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front(); auto x = ins->inputs().front();
auto input = x->inputs().front(); auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape()) if(ins->get_shape() != input->get_shape())
return; return;
...@@ -475,9 +478,8 @@ struct find_resize ...@@ -475,9 +478,8 @@ struct find_resize
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = m.insert_instruction( auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp);
} }
}; };
...@@ -626,9 +628,8 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -626,9 +628,8 @@ struct find_transpose_contiguous_reshaper_unary
auto cont_ins = r.instructions["cont_ins"]; auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name(); auto unary_op_name = ins->get_operator().name();
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination // older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
} }
}; };
...@@ -647,8 +648,8 @@ struct find_broadcast_transpose ...@@ -647,8 +648,8 @@ struct find_broadcast_transpose
{ {
auto transpose = r.result; auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens(); auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"]; auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front(); auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose // scalar transformation does not need extra transpose
if(not input->get_shape().scalar()) if(not input->get_shape().scalar())
{ {
......
...@@ -74,21 +74,27 @@ if(MIGRAPHX_ENABLE_ZENDNN) ...@@ -74,21 +74,27 @@ if(MIGRAPHX_ENABLE_ZENDNN)
target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB}) target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB})
target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB}) target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB})
else() else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl) target_link_libraries(migraphx_cpu PUBLIC DNNL::dnnl)
endif() endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx) target_link_libraries(migraphx_cpu PRIVATE migraphx)
migraphx_generate_export_header(migraphx_cpu) migraphx_generate_export_header(migraphx_cpu)
find_package(OpenMP) find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX) if(WIN32)
# Add library path to rpath to workaround issues with our broken packages target_link_libraries(migraphx_cpu PUBLIC libomp)
foreach(LIBRARY ${OpenMP_CXX_LIBRARIES}) target_include_directories(migraphx_cpu PUBLIC ${OpenMP_CXX_INCLUDE_DIRS})
if(LIBRARY MATCHES "libomp") target_compile_options(migraphx_cpu PUBLIC ${OpenMP_CXX_FLAGS})
get_filename_component(LIBRARY_PATH "${LIBRARY}" PATH) else()
target_link_libraries(migraphx_cpu PUBLIC -Wl,-rpath=${LIBRARY_PATH} -Wl,-rpath-link=${LIBRARY_PATH}) target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
endif() # Add library path to rpath to workaround issues with our broken packages
endforeach() foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
if(LIBRARY MATCHES "libomp")
get_filename_component(LIBRARY_PATH "${LIBRARY}" PATH)
target_link_libraries(migraphx_cpu PUBLIC -Wl,-rpath=${LIBRARY_PATH} -Wl,-rpath-link=${LIBRARY_PATH})
endif()
endforeach()
endif()
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_cpu TARGETS migraphx_cpu
......
...@@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) ...@@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::int32_type: return dt::s32; case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8; case st::int8_type: return dt::s8;
case st::uint8_type: return dt::u8; case st::uint8_type: return dt::u8;
case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
default: MIGRAPHX_THROW("Unsupported data type"); default: MIGRAPHX_THROW("Unsupported data type");
} }
} }
......
...@@ -340,7 +340,6 @@ struct cpu_apply ...@@ -340,7 +340,6 @@ struct cpu_apply
{"reduce_min", "reduction_min"}, {"reduce_min", "reduction_min"},
{"reduce_sum", "reduction_sum"}, {"reduce_sum", "reduction_sum"},
}); });
extend_op("concat", "dnnl::concat"); extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder"); extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution"); extend_op("convolution", "dnnl::convolution");
...@@ -376,6 +375,12 @@ struct cpu_apply ...@@ -376,6 +375,12 @@ struct cpu_apply
// Apply these operators first so the inputs can be const folded // Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl)) for(auto it : iterator_for(*modl))
{ {
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pow") if(it->name() == "pow")
{ {
apply_pow(it); apply_pow(it);
...@@ -383,6 +388,12 @@ struct cpu_apply ...@@ -383,6 +388,12 @@ struct cpu_apply
} }
for(auto it : iterator_for(*modl)) for(auto it : iterator_for(*modl))
{ {
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pooling") if(it->name() == "pooling")
{ {
apply_pooling(it); apply_pooling(it);
......
...@@ -34,23 +34,32 @@ namespace migraphx { ...@@ -34,23 +34,32 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling> struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_v2_forward, op::pooling>
{ {
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; } std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::pooling_v2_forward::desc
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg; : dnnl::algorithm::pooling_avg;
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
// Note: It is not documented, but the default dilation seems to be 0 instead of 1.
// We need to offset dilations with -1.
std::vector<size_t> dilations;
std::transform(op.dilations.cbegin(),
op.dilations.cend(),
std::back_inserter(dilations),
[](size_t d) { return d - 1; });
return {dnnl::prop_kind::forward_inference, return {dnnl::prop_kind::forward_inference,
algo, algo,
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride), to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths), to_dnnl_dims(op.lengths),
to_dnnl_dims(dilations),
to_dnnl_dims(padding_l), to_dnnl_dims(padding_l),
to_dnnl_dims(padding_r)}; to_dnnl_dims(padding_r)};
} }
......
# #################################################################################### # ####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -22,23 +22,22 @@ ...@@ -22,23 +22,22 @@
# THE SOFTWARE. # THE SOFTWARE.
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm) find_package(hip REQUIRED)
find_package(hip)
if(NOT GPU_TARGETS) if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.") set(fatal_msg "HIP package is broken and has no GPU_TARGETS. Please pass GPU_TARGETS to cmake.")
if(NOT WIN32)
set(fatal_msg "${fatal_msg}\nUse -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to build for your GPU.")
endif()
message(FATAL_ERROR ${fatal_msg})
endif() endif()
find_package(miopen) find_package(miopen REQUIRED)
message(STATUS "MIGraphX is using MIOpen")
# rocblas # rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm) find_package(rocblas REQUIRED)
message(STATUS "Build with rocblas") message(STATUS "MIGraphX build with rocBLAS")
if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
if(NOT WIN32) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library) find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif() endif()
...@@ -50,12 +49,11 @@ endif() ...@@ -50,12 +49,11 @@ endif()
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32) if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM KERNEL_FILES list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif() endif()
...@@ -67,8 +65,10 @@ file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/* ...@@ -67,8 +65,10 @@ file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE) add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns) target_compile_features(compile_for_gpu INTERFACE cxx_std_17)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored) target_compile_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_link_libraries(compile_for_gpu INTERFACE hip::device)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE) check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE) if(HAS_HIP_LAMBDA_HOST_DEVICE)
...@@ -103,9 +103,10 @@ rocm_clang_tidy_check(kernel_file_check) ...@@ -103,9 +103,10 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(WIN32) if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows list(REMOVE_ITEM JIT_GPU_SRCS
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
endif() endif()
add_library(migraphx_gpu add_library(migraphx_gpu
...@@ -125,11 +126,8 @@ add_library(migraphx_gpu ...@@ -125,11 +126,8 @@ add_library(migraphx_gpu
fuse_ck.cpp fuse_ck.cpp
fuse_mlir.cpp fuse_mlir.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp
gemm_impl.cpp gemm_impl.cpp
hip.cpp hip.cpp
int8_conv_pack.cpp
int8_gemm_pack.cpp
kernel.cpp kernel.cpp
lowering.cpp lowering.cpp
logsoftmax.cpp logsoftmax.cpp
...@@ -140,9 +138,7 @@ add_library(migraphx_gpu ...@@ -140,9 +138,7 @@ add_library(migraphx_gpu
no_device.cpp no_device.cpp
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp prefuse_ops.cpp
pad.cpp
perfdb.cpp perfdb.cpp
pooling.cpp pooling.cpp
reverse.cpp reverse.cpp
...@@ -170,12 +166,10 @@ endfunction() ...@@ -170,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_ register_migraphx_gpu_ops(hip_
argmax argmax
argmin argmin
gather
logsoftmax logsoftmax
loop loop
multinomial multinomial
nonzero nonzero
pad
prefix_scan_sum prefix_scan_sum
reverse reverse
scatter scatter
...@@ -184,7 +178,6 @@ register_migraphx_gpu_ops(hip_ ...@@ -184,7 +178,6 @@ register_migraphx_gpu_ops(hip_
register_migraphx_gpu_ops(miopen_ register_migraphx_gpu_ops(miopen_
abs abs
contiguous contiguous
int8_conv_pack
lrn lrn
pooling pooling
) )
...@@ -192,10 +185,6 @@ register_op(migraphx_gpu ...@@ -192,10 +185,6 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot> OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
...@@ -219,8 +208,10 @@ if(MIGRAPHX_ENABLE_MLIR) ...@@ -219,8 +208,10 @@ if(MIGRAPHX_ENABLE_MLIR)
endif() endif()
if(MIGRAPHX_USE_HIPRTC) if(MIGRAPHX_USE_HIPRTC)
find_package(hiprtc REQUIRED)
message(STATUS "MIGraphX is using hipRTC") message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1) target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
target_link_libraries(migraphx_gpu PUBLIC hiprtc::hiprtc)
else() else()
message(STATUS "MIGraphX is using HIP Clang") message(STATUS "MIGraphX is using HIP Clang")
...@@ -229,34 +220,47 @@ else() ...@@ -229,34 +220,47 @@ else()
target_flags(HIP_COMPILER_FLAGS hip::device) target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags # Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") string(REGEX REPLACE "--cuda-gpu-arch=[a-z0-9]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") string(REGEX REPLACE "--offload-arch=[a-z0-9:+-]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file # Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ") string(APPEND HIP_COMPILER_FLAGS " ")
if(WIN32)
string(REPLACE "\\" "/" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endif()
foreach(_unused RANGE 2) foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach() endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") message(STATUS "Hip compiler flags: \"${HIP_COMPILER_FLAGS}\"")
target_compile_definitions(migraphx_gpu PRIVATE target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}" -DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}" -DMIGRAPHX_HIP_COMPILER_FLAGS="${HIP_COMPILER_FLAGS}"
) )
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER) if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER) if(WIN32)
execute_process(COMMAND where ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
else()
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
endif()
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER) string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}") target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER="${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif() endif()
endif() endif()
# Check miopen find mode api # Check miopen find mode api
include(CheckLibraryExists) include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION) get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
...@@ -279,11 +283,25 @@ else() ...@@ -279,11 +283,25 @@ else()
message(STATUS "MIOpen does not have find mode api") message(STATUS "MIOpen does not have find mode api")
endif() endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(NOT WIN32) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif() endif()
add_subdirectory(driver) add_subdirectory(driver)
......
...@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis, ...@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs, const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes) const std::vector<std::size_t>& sizes)
{ {
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of( if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; })) inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis}; return {1, axis};
...@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis, ...@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{ {
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty()) if(inputs.empty())
return {1, axis}; return {1, axis};
std::size_t n = std::max_element(inputs.begin(), std::size_t n = std::max_element(inputs.begin(),
......
...@@ -194,7 +194,7 @@ struct hiprtc_program ...@@ -194,7 +194,7 @@ struct hiprtc_program
}; };
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs, std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
std::string params, const std::string& params,
const std::string& arch) const std::string& arch)
{ {
hiprtc_program prog(std::move(srcs)); hiprtc_program prog(std::move(srcs));
...@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags) ...@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags)
} }
} }
std::vector<std::vector<char>> std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) const std::string& params,
const std::string& arch)
{ {
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()}; std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
...@@ -251,10 +252,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -251,10 +252,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::cout << std::string(src.content) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
auto fname = fs::path{"migraphx-hiprtc-driver"};
#ifdef _WIN32
fname.replace_extension(".exe");
#endif
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver"; auto driver = p.parent_path() / fname;
if(fs::exists(driver)) bool found = fs::exists(driver);
if(not found)
{
driver = p.parent_path().parent_path() / "bin" / fname;
found = fs::exists(driver);
}
if(found)
{ {
value v; value v;
v["srcs"] = to_value(hsrcs); v["srcs"] = to_value(hsrcs);
...@@ -270,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -270,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(fs::exists(out)) if(fs::exists(out))
return {read_buffer(out.string())}; return {read_buffer(out.string())};
} }
return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch); return compile_hip_src_with_hiprtc(std::move(hsrcs), params, arch);
} }
#else // MIGRAPHX_USE_HIPRTC #else // MIGRAPHX_USE_HIPRTC
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
std::string, // NOLINT const std::string&, // NOLINT
const std::string&) const std::string&)
{ {
MIGRAPHX_THROW("Not using hiprtc"); MIGRAPHX_THROW("Not using hiprtc");
...@@ -284,16 +296,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -284,16 +296,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool is_hip_clang_compiler() bool is_hip_clang_compiler()
{ {
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++"); static const auto result = fs::path{MIGRAPHX_HIP_COMPILER}.stem() == "clang++";
return result; return result;
} }
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
bool has_compiler_launcher() bool has_compiler_launcher()
{ {
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER)); static const auto result = fs::exists(MIGRAPHX_HIP_COMPILER_LAUNCHER);
return result; return result;
} }
#endif
src_compiler assemble(src_compiler compiler) src_compiler assemble(src_compiler compiler)
{ {
compiler.out_ext = ".S"; compiler.out_ext = ".S";
...@@ -301,37 +317,39 @@ src_compiler assemble(src_compiler compiler) ...@@ -301,37 +317,39 @@ src_compiler assemble(src_compiler compiler)
return compiler; return compiler;
} }
std::vector<std::vector<char>> std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) const std::string& params,
const std::string& arch)
{ {
assert(not srcs.empty()); assert(not srcs.empty());
if(not is_hip_clang_compiler()) if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " + MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_HIP_COMPILER;
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
#endif
if(params.find("-std=") == std::string::npos) if(params.find("-std=") == std::string::npos)
params += " --std=c++17"; compiler.flags += " --std=c++17";
params += " -fno-gpu-rdc"; compiler.flags += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{})) if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g"; compiler.flags += " -g";
params += " -c"; compiler.flags += " -c";
params += " --offload-arch=" + arch; compiler.flags += " --offload-arch=" + arch;
params += " --cuda-device-only"; compiler.flags += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " "; compiler.flags += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG"; compiler.flags += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat "; compiler.flags += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS); compiler.flags += MIGRAPHX_HIP_COMPILER_FLAGS;
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{ {
for(const auto& src : srcs) for(const auto& src : srcs)
...@@ -354,7 +372,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -354,7 +372,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool hip_has_flags(const std::vector<std::string>& flags) bool hip_has_flags(const std::vector<std::string>& flags)
{ {
src_compiler compiler; src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER); compiler.compiler = MIGRAPHX_HIP_COMPILER;
compiler.flags = compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only"; join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
......
...@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0"; options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror"; options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name()); auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1) if(cos.size() != 1)
MIGRAPHX_THROW("No code object"); MIGRAPHX_THROW("No code object");
return code_object_op{value::binary{cos.front()}, return code_object_op{value::binary{cos.front()},
......
...@@ -60,9 +60,8 @@ struct miopen_op ...@@ -60,9 +60,8 @@ struct miopen_op
}; };
MIGRAPHX_REGISTER_OP(miopen_op); MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const std::size_t compile_miopen::compile(operation& op, instruction_ref ins) const
{ {
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs())); auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get<std::size_t>("workspace", 0); return v.get<std::size_t>("workspace", 0);
} }
...@@ -70,25 +69,15 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for ...@@ -70,25 +69,15 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for
void compile_miopen::apply(module& m) const void compile_miopen::apply(module& m) const
{ {
assert(ctx); assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "gpu::miopen_op") if(ins->name() != "gpu::miopen_op")
continue; continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op; auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0; std::size_t ws = 0;
try ws = compile(op, ins);
{ auto inputs = ins->inputs();
// for the regular convolution and convolution_backwards, this try would always succeed auto alloc = m.insert_instruction(
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}})); ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc); inputs.insert(std::prev(inputs.end()), alloc);
......
...@@ -168,6 +168,7 @@ struct compile_plan ...@@ -168,6 +168,7 @@ struct compile_plan
} }
const compiled_result& benchmark(problem_cache& pc) const const compiled_result& benchmark(problem_cache& pc) const
{ {
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(results.empty()) if(results.empty())
MIGRAPHX_THROW("No configs to tune"); MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1) if(results.size() == 1)
...@@ -178,9 +179,10 @@ struct compile_plan ...@@ -178,9 +179,10 @@ struct compile_plan
} }
if(not config) if(not config)
MIGRAPHX_THROW("Multiple kernels without config"); MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" if(trace_level > 0)
<< std::endl; std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) << std::endl;
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl; std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times; std::vector<double> times;
times.reserve(results.size()); times.reserve(results.size());
...@@ -189,22 +191,23 @@ struct compile_plan ...@@ -189,22 +191,23 @@ struct compile_plan
config->solutions.begin(), config->solutions.begin(),
std::back_inserter(times), std::back_inserter(times),
[&](const auto& cr, const auto& solution) { [&](const auto& cr, const auto& solution) {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl; std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value()) if(not cr.has_value())
{ {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) if(trace_level > 1)
std::cout << "No binary" << std::endl; std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
} }
auto t = time_op( auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20); *ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) if(trace_level > 1)
std::cout << t << "ms" << std::endl; std::cout << t << "ms" << std::endl;
return t; return t;
}); });
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i)); pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value()) if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation."); MIGRAPHX_THROW("No valid tuned compilation.");
......
...@@ -43,24 +43,32 @@ template <index_int N, ...@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output) __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{ {
using type = decltype(input(deduce_for_stride(fs))); using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N]; MIGRAPHX_DEVICE_SHARED type buffer[2][N];
type x = init; type x = init;
fs([&](auto i) { fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0) if(idx.local == 0)
buffer[idx.local] = op(input(i), x); buffer[iout][idx.local] = op(input(i), x);
else else
buffer[idx.local] = input(i); buffer[iout][idx.local] = input(i);
__syncthreads(); __syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2) for(index_int s = 1; s < idx.nlocal(); s *= 2)
{ {
if(idx.local + s < idx.nlocal()) iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{ {
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]); buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
} }
__syncthreads(); __syncthreads();
} }
x = buffer[idx.nlocal() - 1]; x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[idx.local]); output(i, buffer[iout][idx.local]);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment