Commit 8d7a8a6c authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 25b33431 a09dc502
......@@ -22,14 +22,8 @@
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -39,76 +33,14 @@ struct parse_pooling : op_parser<parse_pooling>
{
std::vector<op_desc> operators() const
{
return {{"AveragePool", "average"},
return {
{"AveragePool", "average"},
{"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"},
{"MaxPool", "max"},
{"LpPool", "lpnorm"},
{"GlobalLpPool", "lpnorm"}};
}
value handle_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values) const
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
{"GlobalLpPool", "lpnorm"},
};
}
instruction_ref parse(const op_desc& opd,
......@@ -116,148 +48,8 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
}
return add_pooling_op(opd, std::move(info), args[0]);
};
};
} // namespace onnx
......
......@@ -23,6 +23,7 @@
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
......@@ -36,90 +37,56 @@ namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearGlobalAveragePool in *
* Reference: see QLinearAveragePool and QLinearGlobalAveragePool in *
* github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
*/
QLinearGlobalAveragePool consumes an input tensor X and applies
Average pooling across the values in the same channel. This is
equivalent to AveragePool with kernel size equal to the spatial
dimension of input tensor. Input is of type uint8_t or int8_t.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
channels_last : int
Inputs
X : T
Input data tensor from the previous operator; According to channels_last, dimensions for image case
are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
H and W are the height and the width of the data. For non image case, the dimensions are in the form
of (N x C x D1 x D2 ... Dn), or (N x D1 X D2 ... Dn x C) where N is the batch size.
x_scale : tensor(float)
Scale of quantized input 'X'. It must be a scalar.
x_zero_point : T
Zero point tensor for input 'X'. It must be a scalar.
y_scale : tensor(float)
Scale of quantized output 'Y'. It must be a scalar.
y_zero_point : T
Zero point tensor for output 'Y'. It must be a scalar.
Outputs
Y : T
Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
input. with the N and C value keep it value, while the other dimensions are all 1. Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to signed/unsigned int8 tensors.
*/
struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool>
struct parse_qlinearpooling : op_parser<parse_qlinearpooling>
{
std::vector<op_desc> operators() const { return {{"QLinearGlobalAveragePool"}}; }
// basic type checking for QLinearGlobalAveragePool Operator
void check_inputs(const std::vector<instruction_ref>& args) const
std::vector<op_desc> operators() const
{
if(args.size() < 5)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: missing inputs");
return {{"QLinearGlobalAveragePool", "average"}, {"QLinearAveragePool", "average"}};
}
void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
const auto& in_x = args[0];
const auto& zero_pt_x = args[2];
const auto& zero_pt_y = args[4];
const auto onnx_name = opd.onnx_name;
if(in_x->get_shape().ndim() <= 2)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: input dimensions too small");
MIGRAPHX_THROW(onnx_name + ": input dimensions too small");
auto type_x = in_x->get_shape().type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: unsupported input type");
MIGRAPHX_THROW(onnx_name + ": unsupported input type");
const auto& zero_pt_x = args[2];
if(type_x != zero_pt_x->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point");
MIGRAPHX_THROW(onnx_name + ": mismatched type: input zero point");
if(args.size() == 5)
{
const auto& zero_pt_y = args[4];
if(type_x != zero_pt_y->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point");
MIGRAPHX_THROW(onnx_name + ": mismatched type: output zero point");
}
}
instruction_ref parse(const op_desc& /* opd */,
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
if(contains(info.attributes, "channel_last"))
{
int channels_last =
parser.parse_value(info.attributes.at("channels_last")).template at<int>();
if(channels_last != 0)
MIGRAPHX_THROW(
"QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported");
MIGRAPHX_THROW(opd.onnx_name + ": channels_last (N x D1..Dn x C) is not supported");
}
check_inputs(args);
check_inputs(opd, args);
// Input: X
......@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool
const auto& zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info);
// Output Y = globalaveragepool(X)
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = in_x->get_shape().lens();
std::vector<size_t> lengths(lens.begin() + 2, lens.end());
op.lengths = lengths;
op.padding = std::vector<size_t>(lens.size());
auto out_y = info.add_instruction(op, dquant_x);
// Output Y = pooling_op(X)
const auto& scale_y = args[3];
const auto& zero_pt_y = args[4];
auto out_y = add_pooling_op(opd, info, dquant_x);
auto out_quant_y = bcast_qdq_instr("quantizelinear", out_y, scale_y, zero_pt_y, info);
const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", out_y, in_scale_y, args[4], info));
return out_quant_y;
// if no zero_pt: just broadcast the scale..
auto bcast_scale_y = bcast_scalar_instr(out_y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), out_y, bcast_scale_y));
}
};
......
......@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd>
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
std::string reduction = "none";
if(contains(info.attributes, "reduction"))
{
if(info.attributes.at("reduction").s() == "add")
return info.add_instruction(migraphx::make_op("scatternd_add"), args);
if(info.attributes.at("reduction").s() == "mul")
return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
reduction = info.attributes.at("reduction").s();
if(not contains({"none", "add", "mul", "min", "max"}, reduction))
{
MIGRAPHX_THROW("PARSE_SCATTERND: unsupported reduction mode " + reduction);
}
}
return info.add_instruction(migraphx::make_op("scatternd_none"), args);
return info.add_instruction(migraphx::make_op("scatternd_" + reduction), args);
}
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,46 +21,72 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/pad.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
namespace onnx {
// generate unique output stream y, given input stream x;
//
// case unsorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// output(s):
// y: [2, 1, 3, 4] --- the unique output
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
//
// case sorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// output(s):
// y: [1, 2, 3, 4] --- the unique output
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
argument
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
struct parse_unique : op_parser<parse_unique>
{
std::size_t nelements = arg1.get_shape().elements();
hip_visit_all(result, arg1)([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
using hip_index = typename decltype(output)::hip_index;
type device_val = pad_clamp<host_type<type>>(value);
gs_launch(stream, result.get_shape().elements())(
[=](auto i) __device__ { output.data()[i] = device_val; });
hip_index offsets;
std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
gs_launch(stream, nelements)([=](auto i) __device__ {
auto idx = input.get_shape().multi(i);
for(std::size_t j = 0; j < offsets.size(); j++)
std::vector<op_desc> operators() const { return {{"Unique"}}; }
std::vector<instruction_ref> parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
idx[j] += offsets[j];
int64_t sorted = 1; // default = sorted.
if(contains(info.attributes, "sorted"))
sorted = parser.parse_value(info.attributes.at("sorted")).at<int>();
std::optional<int64_t> axis;
if(contains(info.attributes, "axis"))
{
auto n_dim = args[0]->get_shape().ndim();
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
axis = tune_axis(n_dim, *axis, opd.op_name);
}
migraphx::argument data_arg = args.back()->eval();
auto opr = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
: migraphx::make_op("unique", {{"sorted", sorted}});
auto u_opr = info.add_instruction(opr, args.at(0));
auto i_y = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), u_opr);
auto i_y_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), u_opr);
auto i_x_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 2}}), u_opr);
auto i_count = info.add_instruction(make_op("get_tuple_elem", {{"index", 3}}), u_opr);
return {i_y, i_y_idx, i_x_idx, i_count};
}
output[idx] = input.data()[i];
});
});
return result;
}
};
} // namespace device
} // namespace gpu
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values)
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global") or starts_with(opd.onnx_name, "QLinearGlobal"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
}
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0)
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_pooling_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}), l1);
}
return l1;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8;
case st::uint8_type: return dt::u8;
case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
default: MIGRAPHX_THROW("Unsupported data type");
}
}
......
......@@ -340,7 +340,6 @@ struct cpu_apply
{"reduce_min", "reduction_min"},
{"reduce_sum", "reduction_sum"},
});
extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution");
......@@ -376,6 +375,12 @@ struct cpu_apply
// Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl))
{
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pow")
{
apply_pow(it);
......@@ -383,6 +388,12 @@ struct cpu_apply
}
for(auto it : iterator_for(*modl))
{
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pooling")
{
apply_pooling(it);
......
......@@ -126,7 +126,6 @@ add_library(migraphx_gpu
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
hip.cpp
kernel.cpp
......@@ -140,7 +139,6 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
pooling.cpp
reverse.cpp
......@@ -168,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_
argmax
argmin
gather
logsoftmax
loop
multinomial
nonzero
pad
prefix_scan_sum
reverse
scatter
......@@ -263,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -292,10 +290,18 @@ else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
add_subdirectory(driver)
......
......@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
......@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
......
......@@ -38,6 +38,18 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool mlir_enabled()
{
......@@ -49,6 +61,26 @@ bool mlir_enabled()
#endif
}
static bool is_requested(std::string_view option, bool fallback = false)
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
bool mlir_attention_enabled()
{
#ifdef MIGRAPHX_MLIR
if(not mlir_enabled())
return false;
return is_requested("attention");
#else
return false;
#endif
}
#ifdef MIGRAPHX_MLIR
struct mlir_op
......@@ -62,41 +94,27 @@ struct mlir_op
return pack(f(self.op, "op"));
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
{
module_ref mod = mods[0];
check_shapes{inputs, *this}.packed_or_broadcasted();
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end());
for(const std::string& param_name : names)
{
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
}
for(auto ins : iterator_for(*mod))
{
if(ins->name() == "@param")
{
continue;
}
if(ins->name() == "@literal")
if(ins->name() == "@literal" or ins->name() == "@param")
{
ins_shapes[ins] = ins->get_shape();
continue;
}
if(ins->name() == "@return")
{
auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
......@@ -112,38 +130,55 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
std::tuple<instruction_ref, std::vector<operation>>
get_fusable_input_op_stream(instruction_ref lower_input)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
instruction_ref upper_input = lower_input;
std::vector<operation> op_stream;
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
while(contains({"slice",
"transpose",
"multibroadcast",
"broadcast",
"contiguous",
"reshape",
"squeeze",
"flatten",
"unsqueeze"},
upper_input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
operation op = upper_input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
op = migraphx::make_op("reshape", {{"dims", upper_input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
upper_input = upper_input->inputs().at(0);
}
top_inputs.push_back(input);
return {upper_input, op_stream};
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm,
const std::vector<instruction_ref>& gemm_based_op_inputs,
const operation& gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op_inputs)
{
auto [upper_input, op_stream] = get_fusable_input_op_stream(input);
top_inputs.push_back(upper_input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
mm->add_parameter("y" + std::to_string(input_cnt++), upper_input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
instruction_ref new_gemm_based_op = mm->add_instruction(gemm_based_op, imm_inputs);
return {new_gemm_based_op, top_inputs};
}
......@@ -205,20 +240,9 @@ auto is_mlir_conv(mlir_mode mode)
});
}
struct find_mlir_fused_ops
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
......@@ -228,18 +252,41 @@ struct find_mlir_fused_ops
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
instruction_ref mbcast =
mm->add_instruction(make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
std::vector<instruction_ref>
fold_pointwise_mod(instruction_ref pm_ins,
module_ref parent_mod,
const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
{
auto* pm = pm_ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
std::transform(names.begin(),
names.end(),
pm_ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&](auto name, auto input) {
if(ins_map.count(input))
return std::make_pair(pm->get_parameter(name), ins_map.at(input));
return std::make_pair(pm->get_parameter(name),
parent_mod->add_parameter(name, input->get_shape()));
});
return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i)
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
......@@ -300,6 +347,27 @@ struct find_mlir_fused_ops
});
}
return false;
}
MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
{
if(ins->name() != "pointwise")
return false;
auto* pm = ins->module_inputs().front();
return std::all_of(pm->begin(), pm->end(), [&](const auto& i) {
return is_pointwise_op_supported_by_mlir(i);
});
}
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
......@@ -309,29 +377,12 @@ struct find_mlir_fused_ops
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
// Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return;
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
std::vector<instruction_ref> inputs;
std::copy_if(ins->inputs().begin(),
......@@ -349,52 +400,103 @@ struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
auto gemm_based_op = r.result;
//
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm =
mpm.create_module("mlir_" + conv_based_op->name() + std::to_string(counter++));
mpm.create_module("mlir_" + gemm_based_op->name() + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm});
}
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
struct find_mlir_standalone_attention_op
{
auto matcher() const
{
return match::name("gpu::pre_gemm_softmax_gemm").bind("gemm_softmax_gemm");
}
bool is_requested(std::string_view option, bool fallback = false)
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
std::vector<instruction_ref> inputs;
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto gemm0_inputs = gemm_softmax_gemm->inputs();
gemm0_inputs.pop_back();
auto [gemm0, top_gemm0_inputs] =
fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));
inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());
// handle scale
auto v = gemm_softmax_gemm->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
auto scale_lit = mm->add_literal(literal{shape{gemm0->get_shape().type()}, {scale}});
instruction_ref scale_lit_mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);
auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] =
get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
instruction_ref new_upper_v = mm->add_parameter("z", old_upper_v->get_shape());
for(const auto& op : reverse(upper_v_op_stream))
{
new_upper_v = mm->add_instruction(op, {new_upper_v});
}
inputs.push_back(old_upper_v);
auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});
ins_map[gemm_softmax_gemm] = gemm1;
auto ins_to_replace = gemm1;
auto ins_to_be_replaced = gemm_softmax_gemm;
if(r.instructions.find("trailing_pm") != r.instructions.end())
{
ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
r.instructions["trailing_pm"]->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_softmax_gemm; });
ins_to_be_replaced = r.instructions["trailing_pm"];
}
mm->add_return({ins_to_replace});
mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
}
};
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
auto matcher() const
{
auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
return mlir_pointwise()(
match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
;
}
};
} // namespace
#endif // MIGRAPHX_MLIR
......@@ -416,13 +518,20 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
mlir_mode mode =
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
// Attention offloads; default disabled
if(mlir_attention_enabled())
{
match::find_matches(mpm, find_mlir_attention_fused_ops{});
match::find_matches(mpm, find_mlir_standalone_attention_op{});
}
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
(void)mpm;
......
......@@ -22,11 +22,14 @@
* THE SOFTWARE.
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using microseconds = std::chrono::duration<double, std::micro>;
......@@ -34,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
......@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......@@ -183,12 +200,17 @@ struct gemm_impl
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
if(arg_type == rocblas_datatype_f8_r)
{
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
......@@ -216,6 +238,34 @@ struct gemm_impl
}
void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
{
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if(rocblas_fp8_available() and
std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
}
else
#endif
{
if(strided_batched)
{
......@@ -236,6 +286,7 @@ struct gemm_impl
gemm_flags);
}
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
......@@ -331,7 +382,6 @@ struct gemm_impl
num_matrices,
compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
......@@ -366,6 +416,7 @@ struct gemm_impl
ldd,
compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
......@@ -481,8 +532,8 @@ struct gemm_impl
rocblas_int b_stride = 0;
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
......
......@@ -34,6 +34,7 @@ struct module_pass_manager;
namespace gpu {
MIGRAPHX_GPU_EXPORT bool mlir_enabled();
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
......
......@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
}
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
static bool is_mlir_supported_type(shape::type_t t)
{
return contains({shape::type_t::float_type, shape::half_type}, t);
}
};
} // namespace gpu
......
......@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();
MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,41 +21,58 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_pad
template <typename Derived>
struct scatter_compiler : compiler<Derived>
{
op::pad op;
template <class Self, class F>
static auto reflect(Self& self, F f)
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return migraphx::reflect(self.op, f);
const auto inputs =
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
hip_compile_options options;
options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = derived().get_kernel_name(op);
options.virtual_inputs = inputs;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options.params += "-Wno-float-equal";
const auto src = derived().make_interpolated_string(op);
return prepend_copy_data_to_output(compile_hip_code_object(src, options));
}
std::string name() const { return "gpu::pad"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
compiler_replace prepend_copy_data_to_output(const operation& co) const
{
return shapes.size() - 1;
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler>
struct scatternd_compiler : scatter_compiler<scatternd_compiler>
{
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
return {
"scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
std::string make_interpolated_string(const operation& op) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(
ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
};
} // namespace gpu
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
return __builtin_bit_cast(To, fr);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
......@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
return y;
}
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
template <class T, class F>
__device__ T dpp_op(T& x, F f)
{
static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
......@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
input.data = x;
for(index_int i = 0; i < n; i++)
{
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
output.reg[i] = f(input.reg[i]);
}
return output.data;
}
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
return dpp_op(x,
[](auto i) { return __hip_move_dpp(i, DppCtrl, RowMask, BankMask, BoundCtrl); });
}
template <unsigned int Mask, class T>
__device__ T dpp_swizzle(T& x)
{
return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
}
#endif // MIGRAPHX_HAS_DPP
} // namespace migraphx
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment