Unverified Commit 70d9faf7 authored by Chris Austen, committed by GitHub

Merge branch 'develop' into mi200

parents a56c531c a60bdb67
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
......@@ -41,8 +41,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantable_types = {
......@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return quantable_types;
}
void quantize_int8_pass::apply(module& m) const // NOLINT
void quantize_8bits_pass::apply(module& m) const // NOLINT
{
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
......@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto input = ins->inputs().front();
auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
if(contains(quantizable_types, s.type()) and s.type() != precision)
{
auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
auto zero_point =
m.add_literal(migraphx::literal{migraphx::shape{precision}, {param.second}});
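// param.first holds the quantization scale factor; the literal stores its
// reciprocal, presumably so the inserted quantizelinear (which divides by
// its scale input) ends up multiplying by param.first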
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens();
scale =
......@@ -87,9 +86,11 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void capture_arguments_pass::apply(module& m) const // NOLINT
{
assert(param_index != nullptr);
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
if((not contains(ins_names, ins->name())) or (ins->name() == "convert"))
{
continue;
}
......@@ -97,10 +98,17 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
if(contains(quantizable_types, input->get_shape().type()))
{
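// wrap each quantizable input in a capture op; at runtime the callback f
// records the argument values (presumably to derive quantization parameters)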
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
else
{
new_args.push_back(input);
}
}
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
......
......@@ -56,7 +56,11 @@ target make_target(const std::string& name)
{
if(not contains(target_map(), name))
{
#ifdef _WIN32
std::string target_name = "migraphx_" + name + ".dll";
#else
std::string target_name = "libmigraphx_" + name + ".so";
#endif
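// e.g. make_target("gpu") loads migraphx_gpu.dll on Windows and
// libmigraphx_gpu.so elsewhere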
store_target_lib(dynamic_loader(target_name));
}
const auto it = target_map().find(name);
......
......@@ -35,25 +35,14 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(module& m) const
static void replace_with_reduce(module& m, instruction_ref ins)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue;
auto lens = s.lens();
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue;
std::vector<std::int64_t> axes(lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
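// the window covers the whole spatial extent, so the pooling is equivalent
// to a reduction over the spatial axes 2..ndim-1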
// average pooling
if(op.mode == op::pooling_mode::average)
{
......@@ -64,6 +53,131 @@ void rewrite_pooling::apply(module& m) const
{
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
}
}
static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
// Ignore N, C axes
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
if(not default_padding)
{
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// We need to pad both ends
dims[idx] += op.padding.at(idx) * 2;
}
}
std::vector<size_t> kernels = op.lengths;
std::vector<size_t> strides = op.stride;
std::vector<size_t> dilations = op.dilations;
std::vector<std::vector<int>> axis_indices;
axis_indices.resize(dims.size());
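// Worked example for one axis: dims = 4, kernel = 2, dilation = 2, stride = 1
// gives windows {0, 2} and {1, 3}, so axis_indices = {0, 2, 1, 3}. After the
// gather below, each window is contiguous, and a plain pooling with
// stride == kernel length reads exactly one window at a time.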
for(auto idx{0}; idx < dims.size(); ++idx)
{
// Only consider window positions where the dilated kernel fits inside the input
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
stride += strides.at(idx))
{
for(size_t step{0}; step < kernels.at(idx); ++step)
{
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
}
}
}
auto elements = ins->inputs().front();
if(not default_padding)
{
// The pad op supports asymmetric padding, so we provide both ends
std::vector<size_t> padding(2 * s.lens().size(), 0);
// Format is e.g. {N, C, P1, P2, N, C, P1, P2}
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// Ignore N, C axes
padding.at(2 + idx) = op.padding.at(idx);
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
}
// Pad with lowest() so padded elements never win in max pooling
elements = m.insert_instruction(
ins,
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
elements);
}
for(auto idx{0}; idx < axis_indices.size(); ++idx)
{
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
elements = m.insert_instruction(
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
}
// Ignore padding
std::vector<size_t> new_padding(kernels.size(), 0);
// The kernel window elements are placed next to each other, e.g. {x1, y1, x2, y2, ...}
// Stride by the kernel length so adjacent windows do not overlap
std::vector<size_t> new_strides(kernels);
// Ignore dilations
std::vector<size_t> new_dilations(kernels.size(), 1);
m.replace_instruction(ins,
make_op("pooling",
{{"mode", op.mode},
{"padding", new_padding},
{"stride", new_strides},
{"lengths", kernels},
{"dilations", new_dilations}}),
elements);
}
void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
bool same_kernel_as_shape = std::equal(
s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
bool default_strides =
std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
bool default_dilations =
std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
{
replace_with_reduce(m, ins);
}
else if(not default_dilations)
{
// Dilated AvgPool with padding is not supported
if(not default_padding and op.mode == op::pooling_mode::average)
{
continue;
}
auto size =
std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
// Skip overly large inputs, since the gather index literals would get too big
if(size > 100000)
{
continue;
}
replace_dilations_with_gather_pooling(m, ins);
}
}
}
......
......@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
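// nearbyint rounds with the current rounding mode (round-to-nearest-even by
// default), whereas round always takes halfway cases away from zero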
if(ins->inputs().size() == 3)
{
......@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
int64_t max_quant = 0;
int64_t min_quant = 0;
double max_quant = 0;
double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max();
min_quant = qt.min();
......@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
std::vector<double> min_data(s.elements(), min_quant);
std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
......
......@@ -27,7 +27,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
......@@ -461,7 +461,7 @@ struct stream_info
std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first);
......
......@@ -941,15 +941,6 @@ struct find_splits
{
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
// Insert contiguous for reshapes
auto outputs = i->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
m.replace_instruction(i, split->get_operator(), c);
}
......@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
......@@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len;
});
// insert the reshape instruction
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
......@@ -22,8 +22,10 @@
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -32,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS {
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
* From:
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/
struct find_static_2in_broadcasts
{
......@@ -60,8 +66,65 @@ struct find_static_2in_broadcasts
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
* Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
* From:
* slice(data, constant_input); two attributes set
* To:
* slice(data); slice.starts, slice.ends, slice.axes set
*/
struct find_const_2in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::ends_axes)
{
// slice(data, starts)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::starts_axes)
{
// slice(data, ends)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
axes_vec = slice_op.axes;
}
else
{
// slice(data, axes)
inputs.at(1)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
ends_vec = slice_op.ends;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
};
/**
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
* From:
* slice(data, constant_input1, constant_input2); one attribute set
* To:
* slice(data); slice.starts, slice.ends, slice.axes set
*/
struct find_const_3in_slice
{
......@@ -76,27 +139,51 @@ struct find_const_3in_slice
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::axes_only)
{
// slice(data, starts, ends)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::ends_only)
{
// slice(data, starts, axes)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
}
else
{
// slice(data, ends, axes)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
* Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
* From:
* slice(data, constant_starts, constant_ends, constant_axes)
* To:
* slice(data); slice.starts, slice.ends, slice.axes set
*/
struct find_const_4in_slice
{
......@@ -112,9 +199,9 @@ struct find_const_4in_slice
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
argument starts_arg = inputs.at(1)->eval(false);
argument ends_arg = inputs.at(2)->eval(false);
argument axes_arg = inputs.at(3)->eval(false);
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
......@@ -131,10 +218,116 @@ struct find_const_4in_slice
}
};
/**
* Simplify dimensions_of to a literal when the input argument has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
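// to_static(1) yields a static shape, with non-fixed dynamic dims presumably
// clamped to 1; safe here because the dims in [start, end) were checked to be fixed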
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
/**
* Simplify a 2-input reshape whose allocated output dimensions are constant into the static
* 1-input reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
* From:
* x = allocate(constant_output_dims) -> reshape(data, x)
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct find_const_alloc_reshapes
{
auto matcher() const
{
return match::name("reshape")(match::nargs(2),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto reshape_ins = mr.result;
auto reshape_inputs = reshape_ins->inputs();
auto alloc_ins = reshape_inputs.at(1);
argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false);
std::vector<int64_t> output_dims_vec;
output_dims_arg.visit(
[&](auto output) { output_dims_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0));
// have dead_code_elimination remove the previous allocate
}
};
/**
* Simplify a fill operator with a constant value and constant allocated output dimensions into
* a literal. The allocate-into-fill pattern is what parsing the ONNX ConstantOfShape operator
* produces. This replacement could be handled with propagate_constant, but we would rather have
* the simplification happen earlier during compilation.
* This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
* From:
* x = allocate(constant_output_dims) -> fill(constant_value, x)
* To:
* literal
*/
struct find_const_alloc_fill
{
auto matcher() const
{
return match::name("fill")(match::arg(0)(match::is_constant()),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto fill_ins = mr.result;
auto fill_arg = fill_ins->eval(false);
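// both the fill value and the allocated output shape are constant, so the
// result is computed at compile time and the fill is replaced by a literal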
auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data());
m.replace_instruction(fill_ins, l);
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
match::find_matches(m,
find_static_dimensions_of{},
find_const_alloc_reshapes{},
find_static_2in_broadcasts{},
find_const_2in_slice{},
find_const_3in_slice{},
find_const_4in_slice{},
find_const_alloc_fill{});
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -45,77 +45,149 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
struct match_find_quantizable_ops
{
if(ins->name() != "@literal")
static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
return all_same;
}
return all_zeros;
}
struct match_find_quantizable_ops
{
static auto
scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
{
auto prev_ins = qop_arg;
std::vector<instruction_ref> ins_inbetween;
// the matcher skips contiguous, multi/broadcast, and transpose ops; collect all of those
// instructions
while(prev_ins != dqins)
{
ins_inbetween.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto qinp = dqins->inputs().front();
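// replay the collected ops in their original order on the quantized input so
// the quantized op sees the same broadcast/transpose layout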
for(auto ins : reverse_iterator_for(ins_inbetween))
{
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::int8_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
// Only symmetric quantization is supported (i.e. non-zero zero points are not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate the quantized inputs through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref out_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// The input scale should always be a scalar; the weight scale can be a scalar or 1D with
// the same length as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For a (..., M, N) x (..., N, K) dot, only support cases where the quantization axis is M
// for input1 and K for input2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
m.replace_instruction(qop, dq);
}
};
......@@ -138,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals =
std::all_of(
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); });
std::all_of(l1.begin() + 1,
l1.end(),
[&](auto v) {
return ((float_equal(v, l1.front())) or
(std::isinf(l1.front()) and std::isinf(v)));
}) and
std::all_of(l2.begin(), l2.end(), [&](auto v) {
return ((float_equal(v, l1.front())) or (std::isinf(l1.front()) and std::isinf(v)));
});
});
return (x == y) or diff_shapes_equal_vals;
......
......@@ -103,8 +103,6 @@ struct find_reshaper
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -185,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front();
auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape())
return;
......@@ -475,9 +478,8 @@ struct find_resize
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp);
}
};
......@@ -626,9 +628,8 @@ struct find_transpose_contiguous_reshaper_unary
auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name();
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
}
};
......
......@@ -74,21 +74,27 @@ if(MIGRAPHX_ENABLE_ZENDNN)
target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB})
target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB})
else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
target_link_libraries(migraphx_cpu PUBLIC DNNL::dnnl)
endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx)
migraphx_generate_export_header(migraphx_cpu)
find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
# Add library path to rpath to workaround issues with our broken packages
foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
if(WIN32)
target_link_libraries(migraphx_cpu PUBLIC libomp)
target_include_directories(migraphx_cpu PUBLIC ${OpenMP_CXX_INCLUDE_DIRS})
target_compile_options(migraphx_cpu PUBLIC ${OpenMP_CXX_FLAGS})
else()
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
# Add library path to rpath to workaround issues with our broken packages
foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
if(LIBRARY MATCHES "libomp")
get_filename_component(LIBRARY_PATH "${LIBRARY}" PATH)
target_link_libraries(migraphx_cpu PUBLIC -Wl,-rpath=${LIBRARY_PATH} -Wl,-rpath-link=${LIBRARY_PATH})
endif()
endforeach()
endforeach()
endif()
rocm_install_targets(
TARGETS migraphx_cpu
......
......@@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8;
case st::uint8_type: return dt::u8;
case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
default: MIGRAPHX_THROW("Unsupported data type");
}
}
......
......@@ -340,7 +340,6 @@ struct cpu_apply
{"reduce_min", "reduction_min"},
{"reduce_sum", "reduction_sum"},
});
extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution");
......@@ -376,6 +375,12 @@ struct cpu_apply
// Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl))
{
// skip lowering if any of the inputs is fp8, since oneDNN does not
// support fp8 yet
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pow")
{
apply_pow(it);
......@@ -383,6 +388,12 @@ struct cpu_apply
}
for(auto it : iterator_for(*modl))
{
// skip lowering if any of the inputs is fp8, since oneDNN does not
// support fp8 yet
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pooling")
{
apply_pooling(it);
......
......@@ -34,23 +34,32 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_v2_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
dnnl::pooling_v2_forward::desc
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
// Note: it is not documented, but the default dilation in DNNL seems to be 0 instead of 1,
// so we offset the dilations by -1.
std::vector<size_t> dilations;
std::transform(op.dilations.cbegin(),
op.dilations.cend(),
std::back_inserter(dilations),
[](size_t d) { return d - 1; });
return {dnnl::prop_kind::forward_inference,
algo,
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths),
to_dnnl_dims(dilations),
to_dnnl_dims(padding_l),
to_dnnl_dims(padding_r)};
}
......
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -22,23 +22,22 @@
# THE SOFTWARE.
# ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
find_package(hip REQUIRED)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
set(fatal_msg "HIP package is broken and has no GPU_TARGETS. Please pass GPU_TARGETS to cmake.")
if(NOT WIN32)
set(fatal_msg "${fatal_msg}\nUse -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to build for your GPU.")
endif()
message(FATAL_ERROR ${fatal_msg})
endif()
find_package(miopen)
find_package(miopen REQUIRED)
message(STATUS "MIGraphX is using MIOpen")
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
find_package(rocblas REQUIRED)
message(STATUS "MIGraphX build with rocBLAS")
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif()
......@@ -50,12 +49,11 @@ endif()
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32)
# TODO: re-enable when CK is ported to Windows
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
......@@ -67,8 +65,10 @@ file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*
add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_compile_features(compile_for_gpu INTERFACE cxx_std_17)
target_compile_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_link_libraries(compile_for_gpu INTERFACE hip::device)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE)
......@@ -103,9 +103,10 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM JIT_GPU_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
endif()
add_library(migraphx_gpu
......@@ -125,11 +126,8 @@ add_library(migraphx_gpu
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
hip.cpp
int8_conv_pack.cpp
int8_gemm_pack.cpp
kernel.cpp
lowering.cpp
logsoftmax.cpp
......@@ -140,9 +138,7 @@ add_library(migraphx_gpu
no_device.cpp
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
pooling.cpp
reverse.cpp
......@@ -170,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_
argmax
argmin
gather
logsoftmax
loop
multinomial
nonzero
pad
prefix_scan_sum
reverse
scatter
......@@ -184,7 +178,6 @@ register_migraphx_gpu_ops(hip_
register_migraphx_gpu_ops(miopen_
abs
contiguous
int8_conv_pack
lrn
pooling
)
......@@ -192,10 +185,6 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
......@@ -219,8 +208,10 @@ if(MIGRAPHX_ENABLE_MLIR)
endif()
if(MIGRAPHX_USE_HIPRTC)
find_package(hiprtc REQUIRED)
message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
target_link_libraries(migraphx_gpu PUBLIC hiprtc::hiprtc)
else()
message(STATUS "MIGraphX is using HIP Clang")
......@@ -229,34 +220,47 @@ else()
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--cuda-gpu-arch=[a-z0-9]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--offload-arch=[a-z0-9:+-]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip would incorrectly treat them as source files
string(APPEND HIP_COMPILER_FLAGS " ")
if(WIN32)
string(REPLACE "\\" "/" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endif()
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
message(STATUS "Hip compiler flags: \"${HIP_COMPILER_FLAGS}\"")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
-DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"
-DMIGRAPHX_HIP_COMPILER_FLAGS="${HIP_COMPILER_FLAGS}"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
if(WIN32)
execute_process(COMMAND where ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
else()
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
endif()
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER="${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()
# Check miopen find mode api
include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -279,11 +283,25 @@ else()
message(STATUS "MIOpen does not have find mode api")
endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
add_subdirectory(driver)
......
......@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
......@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
......
......@@ -194,7 +194,7 @@ struct hiprtc_program
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
std::string params,
const std::string& params,
const std::string& arch)
{
hiprtc_program prog(std::move(srcs));
......@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags)
}
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
const std::string& params,
const std::string& arch)
{
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
......@@ -251,10 +252,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::cout << std::string(src.content) << std::endl;
}
}
auto fname = fs::path{"migraphx-hiprtc-driver"};
#ifdef _WIN32
fname.replace_extension(".exe");
#endif
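// look for the driver next to this library first, then fall back to the
// sibling bin directory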
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
auto driver = p.parent_path() / fname;
bool found = fs::exists(driver);
if(not found)
{
driver = p.parent_path().parent_path() / "bin" / fname;
found = fs::exists(driver);
}
if(fs::exists(driver))
if(found)
{
value v;
v["srcs"] = to_value(hsrcs);
......@@ -270,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(fs::exists(out))
return {read_buffer(out.string())};
}
return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch);
return compile_hip_src_with_hiprtc(std::move(hsrcs), params, arch);
}
#else // MIGRAPHX_USE_HIPRTC
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
std::string, // NOLINT
const std::string&, // NOLINT
const std::string&)
{
MIGRAPHX_THROW("Not using hiprtc");
......@@ -284,16 +296,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool is_hip_clang_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
static const auto result = fs::path{MIGRAPHX_HIP_COMPILER}.stem() == "clang++";
return result;
}
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
bool has_compiler_launcher()
{
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
static const auto result = fs::exists(MIGRAPHX_HIP_COMPILER_LAUNCHER);
return result;
}
#endif
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
......@@ -301,37 +317,39 @@ src_compiler assemble(src_compiler compiler)
return compiler;
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
const std::string& params,
const std::string& arch)
{
assert(not srcs.empty());
if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_HIP_COMPILER;
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
#endif
if(params.find("-std=") == std::string::npos)
params += " --std=c++17";
params += " -fno-gpu-rdc";
compiler.flags += " --std=c++17";
compiler.flags += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g";
params += " -c";
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
compiler.flags += " -g";
compiler.flags += " -c";
compiler.flags += " --offload-arch=" + arch;
compiler.flags += " --cuda-device-only";
compiler.flags += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG";
compiler.flags += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
compiler.flags += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
compiler.flags += MIGRAPHX_HIP_COMPILER_FLAGS;
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
......@@ -354,7 +372,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
......
......@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1)
MIGRAPHX_THROW("No code object");
return code_object_op{value::binary{cos.front()},
......
......@@ -60,9 +60,8 @@ struct miopen_op
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
std::size_t compile_miopen::compile(operation& op, instruction_ref ins) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get<std::size_t>("workspace", 0);
}
......@@ -70,23 +69,13 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0;
try
{
// for the regular convolution and convolution_backwards, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
ws = compile(op, ins);
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
......
......@@ -168,6 +168,7 @@ struct compile_plan
}
const compiled_result& benchmark(problem_cache& pc) const
{
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
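// trace level 1 prints a per-op summary and the fastest solution;
// level 2 additionally prints the problem and per-solution timings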
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
......@@ -178,9 +179,10 @@ struct compile_plan
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
if(trace_level > 0)
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times;
times.reserve(results.size());
......@@ -189,21 +191,22 @@ struct compile_plan
config->solutions.begin(),
std::back_inserter(times),
[&](const auto& cr, const auto& solution) {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value())
{
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << t << "ms" << std::endl;
return t;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value())
......
......@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{
using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N];
MIGRAPHX_DEVICE_SHARED type buffer[2][N];
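// Double buffering: each pass of the scan reads buffer[iin] and writes
// buffer[iout], so no thread reads an element another thread is writing
// within the same pass (the single-buffer version raced)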
type x = init;
fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0)
buffer[idx.local] = op(input(i), x);
buffer[iout][idx.local] = op(input(i), x);
else
buffer[idx.local] = input(i);
buffer[iout][idx.local] = input(i);
__syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2)
{
if(idx.local + s < idx.nlocal())
iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
}
__syncthreads();
}
x = buffer[idx.nlocal() - 1];
output(i, buffer[idx.local]);
x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[iout][idx.local]);
});
}
......