Unverified Commit 3ae5f9ed authored by Chris Austen, committed by GitHub

Merge branch 'develop' into concat2

parents 6d5a34d2 785ff7d7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
-{"Round", "round"},
+{"Round", "nearbyint"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
...
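The Round remap above is a semantic fix, not just a rename: ONNX Round specifies round-half-to-even, which matches C++ std::nearbyint under the default FE_TONEAREST rounding mode, while std::round rounds halfway cases away from zero. A minimal standalone check of the difference (illustration only, not part of this patch):

#include <cfenv>
#include <cmath>
#include <cstdio>

int main()
{
    std::fesetround(FE_TONEAREST); // default rounding mode: ties go to even
    std::printf("%.1f %.1f\n", std::round(0.5), std::nearbyint(0.5));   // 1.0 0.0
    std::printf("%.1f %.1f\n", std::round(2.5), std::nearbyint(2.5));   // 3.0 2.0
    std::printf("%.1f %.1f\n", std::round(-1.5), std::nearbyint(-1.5)); // -2.0 -2.0
}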
...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
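// For reference (ONNX LSTM spec; note added for clarity): with layout=1 the model
// supplies X as [batch_size, seq_length, input_size] and initial_h/initial_c as
// [batch_size, num_directions, hidden_size], so perm {1, 0, 2} restores the
// layout=0 ordering expected by the lstm op. On output, perm {2, 0, 1, 3} maps the
// internal [seq_length, num_directions, batch_size, hidden_size] hidden states back
// to layout=1's [batch_size, seq_length, num_directions, hidden_size], and
// perm {1, 0, 2} does the same for the last-output and last-cell-output tensors.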
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined operators to make 8 arguments
if(args.size() < 8)
{
...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins);
}
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output};
}
};
...
...@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long as its
// number of elements is batch_size * sample_size.
size_t batch_size = s0.lens().front();
-auto rand_dummy = info.add_literal(
-migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
+auto rand_dummy = info.add_literal(migraphx::literal{
+migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
+std::vector<float>(batch_size * sample_size)});
randoms =
info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
}
...
...@@ -22,14 +22,8 @@
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
-#include <migraphx/onnx/checks.hpp>
-#include <migraphx/onnx/padding.hpp>
-#include <migraphx/op/pad.hpp>
-#include <migraphx/op/pooling.hpp>
+#include <migraphx/onnx/pooling.hpp>
#include <migraphx/instruction.hpp>
-#include <migraphx/ranges.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,68 +33,14 @@ struct parse_pooling : op_parser<parse_pooling>
{
std::vector<op_desc> operators() const
{
-return {{"AveragePool", "average"},
-{"GlobalAveragePool", "average"},
-{"GlobalMaxPool", "max"},
-{"MaxPool", "max"},
-{"LpPool", "lpnorm"},
-{"GlobalLpPool", "lpnorm"}};
+return {
+{"AveragePool", "average"},
+{"GlobalAveragePool", "average"},
+{"GlobalMaxPool", "max"},
+{"MaxPool", "max"},
+{"LpPool", "lpnorm"},
+{"GlobalLpPool", "lpnorm"},
+};
}
-value handle_values(const op_desc& opd,
-onnx_parser::node_info info,
-const shape& in_shape,
-value values) const
-{
-auto kdims = in_shape.ndim() - 2;
-if(starts_with(opd.onnx_name, "Global"))
-{
-// if spatial dimensions are dynamic use dyn_global flag
-if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
-in_shape.dyn_dims().cend(),
-[](auto dd) { return not dd.is_fixed(); }))
-{
-values["dyn_global"] = true;
-values["lengths"] = std::vector<size_t>();
-}
-else
-{
-// works with static and fixed dynamic shape
-auto m_lens = in_shape.max_lens();
-values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
-}
-}
-if(contains(info.attributes, "ceil_mode"))
-{
-values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
-}
-if(contains(info.attributes, "strides"))
-{
-values["stride"].clear();
-copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
-check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
-}
-if(contains(info.attributes, "kernel_shape"))
-{
-values["lengths"].clear();
-copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
-check_attr_sizes(
-kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
-}
-// lp_order attribute
-if(contains(info.attributes, "p"))
-{
-values["lp_order"] = info.attributes.at("p").i();
-}
-// ensure pads available only when auto_pad is "NOT_SET"
-check_padding_mode(info, "POOLING");
-return values;
-}
instruction_ref parse(const op_desc& opd,
...@@ -108,144 +48,8 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
-std::string mode = opd.op_name;
-const std::unordered_map<std::string, op::pooling_mode> mode_map = {
-{"max", op::pooling_mode::max},
-{"average", op::pooling_mode::average},
-{"lpnorm", op::pooling_mode::lpnorm}};
-if(not contains(mode_map, mode))
-{
-MIGRAPHX_THROW(
-"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
-}
-operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
-value values = op.to_value();
-auto l0 = args[0];
-auto in_shape = l0->get_shape();
-assert(in_shape.ndim() > 2);
-auto kdims = in_shape.ndim() - 2;
-values = handle_values(opd, info, in_shape, values);
-// count include padding, if count include pad is 1, we always use
-// explicit pad
-int count_include_pad = 0;
-if(contains(info.attributes, "count_include_pad"))
-{
-if(in_shape.dynamic())
-{
-MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
-"dynamic input shape");
-}
-count_include_pad = info.attributes.at("count_include_pad").i();
-}
-std::vector<int64_t> paddings;
-float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
-if(contains(info.attributes, "pads"))
-{
-values["padding"].clear();
-copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
-check_attr_sizes(
-kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
-}
-if(paddings.size() != 2 * kdims)
-{
-paddings.resize(kdims * 2);
-std::fill_n(paddings.begin(), 2 * kdims, 0);
-}
-if(values["padding"].size() != kdims)
-{
-values["padding"].resize(kdims);
-std::fill_n(values["padding"].begin(), kdims, 0);
-}
-if(values["stride"].size() != kdims)
-{
-values["stride"].resize(kdims);
-std::fill_n(values["stride"].begin(), kdims, 1);
-}
-// used to calculate the supposed output shape
-std::vector<int64_t> orig_padding = paddings;
-// TODO: add parsing for dilations
-if(contains(info.attributes, "auto_pad") and
-to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
-{
-auto auto_pad = to_upper(info.attributes["auto_pad"].s());
-// don't use the given padding sizes, if any
-// values["padding"].clear();
-if(in_shape.dynamic())
-{
-// set padding_mode to trigger auto padding at runtime
-bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
-values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
-: to_value(op::padding_mode_t::same_lower);
-}
-else
-{
-// Calculate auto padding
-// dilations (argument 4) not supported; default to all 1's
-cal_auto_padding_size(info,
-values,
-values["lengths"].to_vector<std::size_t>(),
-std::vector<size_t>(in_shape.ndim() - 2, 1),
-in_shape.lens(),
-paddings);
-values["padding"] = paddings;
-// default padding_mode indicates that padding sizes are not calculated dynamically
-values["padding_mode"] = migraphx::op::padding_mode_t::default_;
-}
-}
-std::vector<int64_t> slice_start;
-std::vector<int64_t> slice_end;
-tune_padding_size(values, paddings, count_include_pad, slice_start);
-if(not slice_start.empty())
-{
-if(in_shape.dynamic())
-{
-MIGRAPHX_THROW(
-"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
-}
-// calculate expected output shape
-orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
-orig_padding.insert(orig_padding.begin(), 2, 0);
-op::pad pad{orig_padding, 0.0f};
-shape padded_shape = pad.compute_shape({l0->get_shape()});
-// make an op just to get its output shape
-auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
-// compute slice_end information
-slice_end.resize(slice_start.size());
-std::transform(out_lens.begin() + 2,
-out_lens.end(),
-slice_start.begin(),
-slice_end.begin(),
-[](auto i, auto j) { return i + j; });
-}
-values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
-check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
-op.from_value(values);
-auto l1 = info.add_instruction(op, l0);
-if(not slice_start.empty())
-{
-std::vector<int64_t> axes(kdims);
-std::iota(axes.begin(), axes.end(), 2);
-l1 = info.add_instruction(
-make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
-l1);
-}
-return l1;
+return add_pooling_op(opd, std::move(info), args[0]);
}
};
} // namespace onnx
...
...@@ -36,7 +36,7 @@ namespace onnx {
/*
*********************************************************************************
-* Reference: see QLinearAdd in *
+* Reference: see QLinearAdd, QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
...@@ -49,6 +49,17 @@ namespace onnx {
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
+com.microsoft.QLinearMul
+Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
+support).
+C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
+Version
+This version of the operator has been available since version 1 of the 'com.microsoft' operator
+set.
+General definition of binary QLinear* ops:
Inputs (7 - 8)
A : T
First operand.
...@@ -88,15 +99,18 @@ namespace onnx {
*/
-struct parse_qlinearadd : op_parser<parse_qlinearadd>
+struct parse_qlinearbinary : op_parser<parse_qlinearbinary>
{
-std::vector<op_desc> operators() const { return {{"QLinearAdd"}}; }
+std::vector<op_desc> operators() const
+{
+return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}};
+}
-// basic type checking for QLinearAdd Operator
-void check_inputs(const std::vector<instruction_ref>& args) const
+// basic type checking for binary QLinear Operator
+void check_inputs(const std::vector<instruction_ref>& args, const std::string& op_name) const
{
if(args.size() < 7)
-MIGRAPHX_THROW("QLINEARADD: missing inputs");
+MIGRAPHX_THROW(op_name + ": missing inputs");
const auto& in_a = args[0];
const auto& in_b = args[3];
...@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
auto type_a = sh_a.type();
auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
-MIGRAPHX_THROW("QLINEARADD: unsupported input type");
+MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
-MIGRAPHX_THROW("QLINEARADD: unsupported input type");
+MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_a != type_b)
-MIGRAPHX_THROW("QLINEARADD: mismatched input types");
+MIGRAPHX_THROW(op_name + ": mismatched input types");
}
-instruction_ref parse(const op_desc& /* opd */,
+instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
-check_inputs(args);
+check_inputs(args, opd.op_name);
// A
const auto& in_a = args[0];
...@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);
-// C = A + B
-auto out_c = info.add_common_op("add", dquant_a, dquant_b);
+// C = op(A, B)
+auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b);
const auto& in_scale_c = args[6];
...
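The dequantize-multiply-requantize identity quoted in the comment block above can be sanity-checked with scalar values; a toy trace with made-up scales and zero points (illustration only, not MIGraphX code):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    // C = saturate(round(((A - A_zp) * A_scale * (B - B_zp) * B_scale) / C_scale) + C_zp)
    const float a_scale = 0.1f, b_scale = 0.2f, c_scale = 0.05f;
    const int a_zp = 128, b_zp = 128, c_zp = 128;
    const int a = 150, b = 140;                                 // dequantize to 2.2 and 2.4
    float real = (a - a_zp) * a_scale * ((b - b_zp) * b_scale); // 2.2 * 2.4 = 5.28
    int q = static_cast<int>(std::lround(real / c_scale)) + c_zp;
    std::printf("C = %d\n", std::clamp(q, 0, 255));             // 234, dequantizes to ~5.3
}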
...@@ -23,6 +23,7 @@
*/
#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/pooling.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
...@@ -36,90 +37,56 @@ namespace onnx {
/*
*********************************************************************************
-* Reference: see QLinearGlobalAveragePool in *
+* Reference: see QLinearAveragePool and QLinearGlobalAveragePool in *
* github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
-QLinearGlobalAveragePool consumes an input tensor X and applies
-Average pooling across the values in the same channel. This is
-equivalent to AveragePool with kernel size equal to the spatial
-dimension of input tensor. Input is of type uint8_t or int8_t.
-Version
-This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
-Attributes
-channels_last : int
-Inputs
-X : T
-Input data tensor from the previous operator; According to channels_last, dimensions for image case
-are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
-H and W are the height and the width of the data. For non image case, the dimensions are in the form
-of (N x C x D1 x D2 ... Dn), or (N x D1 x D2 ... Dn x C) where N is the batch size.
-x_scale : tensor(float)
-Scale of quantized input 'X'. It must be a scalar.
-x_zero_point : T
-Zero point tensor for input 'X'. It must be a scalar.
-y_scale : tensor(float)
-Scale of quantized output 'Y'. It must be a scalar.
-y_zero_point : T
-Zero point tensor for output 'Y'. It must be a scalar.
-Outputs
-Y : T
-Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
-input, with the N and C values kept, while the other dimensions are all 1.
-Type Constraints
-T : tensor(uint8), tensor(int8)
-Constrain input and output types to signed/unsigned int8 tensors.
*/
-struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool>
+struct parse_qlinearpooling : op_parser<parse_qlinearpooling>
{
-std::vector<op_desc> operators() const { return {{"QLinearGlobalAveragePool"}}; }
+std::vector<op_desc> operators() const
+{
+return {{"QLinearGlobalAveragePool", "average"}, {"QLinearAveragePool", "average"}};
+}
-// basic type checking for QLinearGlobalAveragePool Operator
-void check_inputs(const std::vector<instruction_ref>& args) const
+void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
-if(args.size() < 5)
-MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: missing inputs");
-const auto& in_x = args[0];
-const auto& zero_pt_x = args[2];
-const auto& zero_pt_y = args[4];
+const auto& in_x = args[0];
+const auto onnx_name = opd.onnx_name;
if(in_x->get_shape().ndim() <= 2)
-MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: input dimensions too small");
+MIGRAPHX_THROW(onnx_name + ": input dimensions too small");
auto type_x = in_x->get_shape().type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
-MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: unsupported input type");
+MIGRAPHX_THROW(onnx_name + ": unsupported input type");
+const auto& zero_pt_x = args[2];
if(type_x != zero_pt_x->get_shape().type())
-MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point");
-if(type_x != zero_pt_y->get_shape().type())
-MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point");
+MIGRAPHX_THROW(onnx_name + ": mismatched type: input zero point");
+if(args.size() == 5)
+{
+const auto& zero_pt_y = args[4];
+if(type_x != zero_pt_y->get_shape().type())
+MIGRAPHX_THROW(onnx_name + ": mismatched type: output zero point");
+}
}
-instruction_ref parse(const op_desc& /* opd */,
+instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
-int channels_last =
-parser.parse_value(info.attributes.at("channels_last")).template at<int>();
-if(channels_last != 0)
-MIGRAPHX_THROW(
-"QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported");
+if(contains(info.attributes, "channels_last"))
+{
+int channels_last =
+parser.parse_value(info.attributes.at("channels_last")).template at<int>();
+if(channels_last != 0)
+MIGRAPHX_THROW(opd.onnx_name + ": channels_last (N x D1..Dn x C) is not supported");
+}
-check_inputs(args);
+check_inputs(opd, args);
// Input: X
...@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool
const auto& zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info);
-// Output Y = globalaveragepool(X)
-auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
-auto lens = in_x->get_shape().lens();
-std::vector<size_t> lengths(lens.begin() + 2, lens.end());
-op.lengths = lengths;
-op.padding = std::vector<size_t>(lens.size());
-auto out_y = info.add_instruction(op, dquant_x);
+// Output Y = pooling_op(X)
+auto out_y = add_pooling_op(opd, info, dquant_x);
-const auto& scale_y = args[3];
-const auto& zero_pt_y = args[4];
-auto out_quant_y = bcast_qdq_instr("quantizelinear", out_y, scale_y, zero_pt_y, info);
+const auto& in_scale_y = args[3];
+// zero_pt for Y is supplied as the last optional argument..
+if(args.size() == 5)
+return (bcast_qdq_instr("quantizelinear", out_y, in_scale_y, args[4], info));
-return out_quant_y;
+// if no zero_pt: just broadcast the scale..
+auto bcast_scale_y = bcast_scalar_instr(out_y->get_shape(), in_scale_y, info);
+return (info.add_instruction(migraphx::make_op("quantizelinear"), out_y, bcast_scale_y));
}
};
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearSigmoid, QLinearLeakyRelu in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
com.microsoft.QLinearSigmoid
QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces
one output data (Tensor) where the function f(x) = quantize(Sigmoid(dequantize(x))), is applied to
the data tensor elementwise. Where the function Sigmoid(x) = 1 / (1 + exp(-x))
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
*****************************************************************************************************
com.microsoft.QLinearLeakyRelu
QLinearLeakyRelu takes quantized input data (Tensor), an argument alpha, and quantize parameter for
output, and produces one output data (Tensor) where the function f(x) = quantize(alpha *
dequantize(x)) for dequantize(x) < 0, f(x) = quantize(dequantize(x)) for dequantize(x) >= 0, is
applied to the data tensor elementwise.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
alpha : float
Coefficient of leakage.
******************************************************************************************************
Generic input layout of QLinear unary operators:
Inputs (4 - 5)
X : T
Input tensor
X_scale : tensor(float)
Input X's scale. It's a scalar, which means a per-tensor/layer quantization.
X_zero_point (optional) : T
Input X's zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Y_scale : tensor(float)
Output Y's scale. It's a scalar, which means a per-tensor/layer quantization.
Y_zero_point (optional) : T
Output Y's zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Outputs
Y : T
Output tensor
Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit tensors.
*/
struct parse_qlinearunary : op_parser<parse_qlinearunary>
{
std::vector<op_desc> operators() const
{
return {{"QLinearSigmoid", "sigmoid"}, {"QLinearLeakyRelu", "leaky_relu"}};
}
void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
if(args.size() < 4)
MIGRAPHX_THROW(opd.op_name + ": missing inputs");
const auto& in_x = args[0];
auto sh_x = in_x->get_shape();
auto type_x = sh_x.type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW(opd.op_name + ": unsupported input type");
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(opd, args);
// X
const auto& in_x = args[0];
const auto& in_scale_x = args[1];
const auto& in_zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, in_scale_x, in_zero_pt_x, info);
// Y = op(dequantizelinear(x))
auto op = parser.load(opd.op_name, info);
auto y = info.add_instruction(op, dquant_x);
const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", y, in_scale_y, args[4], info));
// if no zero_pt: just broadcast the scale..
auto bcast_scale_sigm = bcast_scalar_instr(y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), y, bcast_scale_sigm));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
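As with the binary QLinear ops, the unary formula f(x) = quantize(Sigmoid(dequantize(x))) is easy to trace by hand; a toy example with made-up scales and zero points (illustration only, not MIGraphX code):

#include <cmath>
#include <cstdio>

int main()
{
    const int x = 138, x_zp = 128, y_zp = 0;
    const float x_scale = 0.05f, y_scale = 1.0f / 256;
    float dq = (x - x_zp) * x_scale;             // dequantize: 0.5
    float s  = 1.0f / (1.0f + std::exp(-dq));    // Sigmoid: ~0.6225
    int y    = static_cast<int>(std::lround(s / y_scale)) + y_zp;
    std::printf("y = %d\n", y);                  // quantize: 159
}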
...@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return nearest_mode;
}
static std::vector<double> get_scales(const onnx_parser::attribute_map& attr)
{
std::vector<double> scales;
if(contains(attr, "scales"))
{
copy(attr.at("scales").floats(), std::back_inserter(scales));
}
return scales;
}
static void parse_args(const std::vector<instruction_ref>& args,
const std::vector<size_t>& in_lens,
const std::string& op_name,
std::vector<double>& vec_scale,
std::vector<std::size_t>& out_lens)
{
for(const auto& arg : args)
{
if(arg->name() == "undefined" or arg == args.front())
{
continue;
}
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + op_name +
": specified output size does not match input size");
}
// compute the scale
vec_scale.resize(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
// scale input
if(lens[0] == in_lens.size())
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + op_name + ": dynamic input scale is not supported!");
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
}
}
}
}
struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
...@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
std::vector<std::size_t> out_lens(in_lens.size());
// scale
-std::vector<double> vec_scale;
-for(const auto& arg : args)
-{
-if(arg->name() == "undefined" or arg == args.front())
-{
-continue;
-}
-// skipped empty input
-auto lens = arg->get_shape().lens();
-if(lens.empty())
-{
-continue;
-}
-auto type = arg->get_shape().type();
-// output size
-if(type == shape::int64_type)
-{
-auto arg_out_s = arg->eval();
-check_arg_empty(arg_out_s,
-"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
-arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
-if(out_lens.size() != in_lens.size())
-{
-MIGRAPHX_THROW("PARSE_" + opd.op_name +
-": specified output size does not match input size");
-}
-// compute the scale
-vec_scale.resize(in_lens.size());
-std::transform(in_lens.begin(),
-in_lens.end(),
-out_lens.begin(),
-vec_scale.begin(),
-[](auto iss, auto oss) { return 1.0 * oss / iss; });
-}
-else
-{
-// scale input
-if(lens[0] == in_lens.size())
-{
-auto arg_scale = arg->eval();
-check_arg_empty(arg_scale,
-"PARSE_" + opd.op_name +
-": dynamic input scale is not supported!");
-arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
-if(in_lens.size() != vec_scale.size())
-{
-MIGRAPHX_THROW("PARSE_" + opd.op_name +
-": ranks of input and scale are different!");
-}
-std::transform(in_lens.begin(),
-in_lens.end(),
-vec_scale.begin(),
-out_lens.begin(),
-[&](auto idx, auto scale) {
-return static_cast<std::size_t>(idx * scale);
-});
-}
-}
-}
+std::vector<double> vec_scale = get_scales(info.attributes);
+// If `scales` was not an attribute, it must be an input
+if(vec_scale.empty())
+{
+// Depending on the args, it *must* populate the `vec_scale`, and might populate
+// `out_lens`
+parse_args(args, in_lens, opd.op_name, vec_scale, out_lens);
+}
+if(in_lens.size() != vec_scale.size())
+{
+MIGRAPHX_THROW("PARSE_" + opd.op_name + ": ranks of input and scale are different!");
+}
+// if the output was not calculated yet, we update it based on the scales
+if(all_of(out_lens.cbegin(), out_lens.cend(), [](auto o) { return o == 0; }))
+{
+std::transform(
+in_lens.begin(),
+in_lens.end(),
+vec_scale.begin(),
+out_lens.begin(),
+[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
+}
shape out_s{in_s.type(), out_lens};
...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
-args[0] = info.make_contiguous(args[0]);
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
if(mode == "nearest")
...
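Whichever path supplies the values, sizes and scales end up linked per axis by out_lens[i] = in_lens[i] * vec_scale[i] (truncated toward zero), with the reverse scale = out/in used when explicit sizes are given. A tiny standalone check of that relation (illustration only, not MIGraphX code):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<std::size_t> in_lens{1, 1, 2, 2};
    std::vector<double> vec_scale{1.0, 1.0, 2.0, 2.0}; // per-axis scales
    for(std::size_t i = 0; i < in_lens.size(); ++i)
        std::printf("%zu ", static_cast<std::size_t>(in_lens[i] * vec_scale[i])); // 1 1 4 4
    std::printf("\n");
}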
...@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd>
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
+std::string reduction = "none";
if(contains(info.attributes, "reduction"))
{
-if(info.attributes.at("reduction").s() == "add")
-return info.add_instruction(migraphx::make_op("scatternd_add"), args);
-if(info.attributes.at("reduction").s() == "mul")
-return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
+reduction = info.attributes.at("reduction").s();
+if(not contains({"none", "add", "mul", "min", "max"}, reduction))
+{
+MIGRAPHX_THROW("PARSE_SCATTERND: unsupported reduction mode " + reduction);
+}
}
-return info.add_instruction(migraphx::make_op("scatternd_none"), args);
+return info.add_instruction(migraphx::make_op("scatternd_" + reduction), args);
}
};
...
...@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
+/**
+ * Either inserts the argument into `this->op_args` or returns the constant value of the argument.
+ */
std::vector<int64_t> insert(instruction_ref arg)
{
std::vector<int64_t> result;
...@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes;
}
-if(not sd.steps.empty())
+if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{
if(sd.op.starts.empty() or sd.op.ends.empty())
-MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported");
+MIGRAPHX_THROW(
+"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
}
-assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size()))
{
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// generate unique output stream y, given input stream x;
//
// case unsorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// output(s):
// y: [2, 1, 3, 4] --- the unique output
// y_indices: [0, 1, 3, 4] --- first occurrence, in terms of indices of x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
//
// case sorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// output(s):
// y: [1, 2, 3, 4] --- the unique output
// y_indices: [1, 0, 3, 4] --- first occurrence, in terms of indices of x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
struct parse_unique : op_parser<parse_unique>
{
std::vector<op_desc> operators() const { return {{"Unique"}}; }
std::vector<instruction_ref> parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int64_t sorted = 1; // default = sorted.
if(contains(info.attributes, "sorted"))
sorted = parser.parse_value(info.attributes.at("sorted")).at<int>();
std::optional<int64_t> axis;
if(contains(info.attributes, "axis"))
{
auto n_dim = args[0]->get_shape().ndim();
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
axis = tune_axis(n_dim, *axis, opd.op_name);
}
migraphx::argument data_arg = args.back()->eval();
auto opr = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
: migraphx::make_op("unique", {{"sorted", sorted}});
auto u_opr = info.add_instruction(opr, args.at(0));
auto i_y = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), u_opr);
auto i_y_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), u_opr);
auto i_x_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 2}}), u_opr);
auto i_count = info.add_instruction(make_op("get_tuple_elem", {{"index", 3}}), u_opr);
return {i_y, i_y_idx, i_x_idx, i_count};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values)
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global") or starts_with(opd.onnx_name, "QLinearGlobal"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
}
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0)
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_pooling_values(opd, info, in_shape, values);
// count_include_pad: if it is 1, we always use explicit padding
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations are taken from the parsed attribute (all 1's if not specified)
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}), l1);
}
return l1;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -40,7 +40,7 @@
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
+#include <migraphx/float8.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
...@@ -144,6 +144,18 @@ struct npy_format_descriptor<half>
static constexpr auto name() { return _("half"); }
};
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
static std::string format()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
// TODO: need to figure out correct encoding
return "z";
}
static constexpr auto name() { return _("fp8e4m3fnuz"); }
};
} // namespace detail
} // namespace pybind11
...
...@@ -56,7 +56,11 @@ target make_target(const std::string& name)
{
if(not contains(target_map(), name))
{
+#ifdef _WIN32
+std::string target_name = "migraphx_" + name + ".dll";
+#else
std::string target_name = "libmigraphx_" + name + ".so";
+#endif
store_target_lib(dynamic_loader(target_name));
}
const auto it = target_map().find(name);
...
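Assumed usage (a sketch, not from this patch): the first call for a name not yet in the target map triggers the dynamic load above with the platform-specific file name, e.g.

// resolves libmigraphx_gpu.so on Linux or migraphx_gpu.dll on Windows
// auto t = migraphx::make_target("gpu");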
...@@ -35,6 +35,110 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void replace_with_reduce(module& m, instruction_ref ins)
{
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
auto lens = s.lens();
std::vector<std::int64_t> axes(lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
// average pooling
if(op.mode == op::pooling_mode::average)
{
m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
}
// max pooling
else
{
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
}
}
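// For example: a global average pool over an NCHW input (kernel equal to the
// spatial dims, unit strides, zero padding) is exactly reduce_mean over axes
// {2, 3}, which is why the rewrite above only fires when those defaults hold.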
static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
// Ignore N, C axes
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
if(not default_padding)
{
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// We need to pad both ends
dims[idx] += op.padding.at(idx) * 2;
}
}
std::vector<size_t> kernels = op.lengths;
std::vector<size_t> strides = op.stride;
std::vector<size_t> dilations = op.dilations;
std::vector<std::vector<int>> axis_indices;
axis_indices.resize(dims.size());
for(auto idx{0}; idx < dims.size(); ++idx)
{
// Only consider window positions that fit fully inside the input
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
stride += strides.at(idx))
{
for(size_t step{0}; step < kernels.at(idx); ++step)
{
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
}
}
}
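// Worked 1-D illustration (added for clarity): a spatial dim of 5 with kernel 2,
// dilation 2, stride 1 gives window starts 0, 1, 2, so
// axis_indices = {0, 2, 1, 3, 2, 4}; the gathered row of 6 elements is then
// pooled below with lengths = {2}, stride = {2}, dilations = {1}.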
auto elements = ins->inputs().front();
if(not default_padding)
{
// Pad supports asym, we need to provide both ends
std::vector<size_t> padding(2 * s.lens().size(), 0);
// Format will be e.g {N, C, P1, P2, N, C, P1, P2}
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// Ignore N, C axes
padding.at(2 + idx) = op.padding.at(idx);
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
}
// Default value needed for Max pooling
elements = m.insert_instruction(
ins,
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
elements);
}
for(auto idx{0}; idx < axis_indices.size(); ++idx)
{
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
elements = m.insert_instruction(
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
}
// Ignore padding
std::vector<size_t> new_padding(kernels.size(), 0);
// The kernel window elements are placed next to each other, e.g. {x1, y1, x2, y2, ...};
// stride over them so consecutive windows do not overlap
std::vector<size_t> new_strides(kernels);
// Ignore dilations
std::vector<size_t> new_dilations(kernels.size(), 1);
m.replace_instruction(ins,
make_op("pooling",
{{"mode", op.mode},
{"padding", new_padding},
{"stride", new_strides},
{"lengths", kernels},
{"dilations", new_dilations}}),
elements);
}
void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(m))
...@@ -43,26 +147,36 @@ void rewrite_pooling::apply(module& m) const
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
-if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
-continue;
-if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
-continue;
-auto lens = s.lens();
-if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
-continue;
-std::vector<std::int64_t> axes(lens.size() - 2);
-std::iota(axes.begin(), axes.end(), 2);
-// average pooling
-if(op.mode == op::pooling_mode::average)
+bool same_kernel_as_shape = std::equal(
+s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
+bool default_strides =
+std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
+bool default_padding =
+std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
+bool default_dilations =
+std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
+if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
{
-m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
+replace_with_reduce(m, ins);
}
-// max pooling
-else
+else if(not default_dilations)
{
-m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
+// Dilated AvgPool with padding is not supported
+if(not default_padding and op.mode == op::pooling_mode::average)
+{
+continue;
+}
+auto size =
+std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
+// Can't handle too large an input, because the gather indices are stored as a literal
+if(size > 100000)
+{
+continue;
+}
+replace_dilations_with_gather_pooling(m, ins);
}
}
}
...
...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
-auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
+auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3)
{
...
...@@ -27,7 +27,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp>
-#include <migraphx/par_for.hpp>
+#include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
...@@ -461,7 +461,7 @@ struct stream_info
std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; });
-par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
+simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first);
...
@@ -941,15 +941,6 @@ struct find_splits
{
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
- // Insert contiguous for reshapes
- auto outputs = i->outputs();
- for(auto output : outputs)
- {
-     if(output->name() != "reshape")
-         continue;
-     auto x = m.insert_instruction(output, make_op("contiguous"), i);
-     m.replace_instruction(output, output->get_operator(), x);
- }
m.replace_instruction(i, split->get_operator(), c);
}
@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
- for(auto output : outputs)
- {
-     if(output->name() != "reshape")
-         continue;
-     auto x = m.insert_instruction(output, make_op("contiguous"), arg);
-     m.replace_instruction(output, output->get_operator(), x);
- }
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
@@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len;
});
- // insert the reshape instruction and add contiguous if needed
- if(not input->get_shape().standard())
- {
-     input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
- }
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
@@ -22,8 +22,10 @@
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

@@ -32,6 +34,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 * Convert 2 input static shape broadcast/multibroadcast into 1 input version.
 * Some compiler passes (ex. simplify_algebra) only support the 1 input versions
 * of the broadcasting operators.
 * From:
 * broadcast_op(argument_with_static_shape, argument_with_static_shape)
 * To:
 * broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
 */
struct find_static_2in_broadcasts
{
@@ -60,8 +66,65 @@ struct find_static_2in_broadcasts
};
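/**
 * Illustrative example with made-up shapes: given two static inputs
 *   y = multibroadcast(x, other)   // x: {3, 1}, other: {3, 4}
 * the pass would produce
 *   y = multibroadcast[out_lens={3,4}](x)
 * so later passes that only handle the 1-input form can apply.
 */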
/**
 * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
 * From:
 * slice(data, constant_input); two attributes set
 * To:
 * slice(data); slice.starts, slice.ends, slice.axes set
 */
struct find_const_2in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::ends_axes)
{
// slice(data, starts)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::starts_axes)
{
// slice(data, ends)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
axes_vec = slice_op.axes;
}
else
{
// slice(data, axes)
inputs.at(1)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
ends_vec = slice_op.ends;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
};
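/**
 * Illustrative example with made-up values, for the ends_axes case:
 *   starts = @literal{1}
 *   y = slice[ends={5}, axes={1}](x, starts)
 * becomes
 *   y = slice[starts={1}, ends={5}, axes={1}](x)
 */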
/**
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
* From:
* slice(data, constant_input1, constant_input2); one attribute set
* To:
 * slice(data); slice.starts, slice.ends, slice.axes set
 */
struct find_const_3in_slice
{
@@ -76,27 +139,51 @@ struct find_const_3in_slice
{
auto ins       = mr.result;
auto inputs    = ins->inputs();
auto slice_op  = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::axes_only)
{
// slice(data, starts, ends)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::ends_only)
{
// slice(data, starts, axes)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
}
else
{
// slice(data, ends, axes)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
};
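/**
 * Illustrative example with made-up values, for the axes_only case:
 *   y = slice[axes={0}](x, @literal{0}, @literal{3})
 * becomes
 *   y = slice[starts={0}, ends={3}, axes={0}](x)
 */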
/**
 * Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
 * From:
 * slice(data, constant_starts, constant_ends, constant_axes)
 * To:
 * slice(data); slice.starts, slice.ends, slice.axes set
 */
struct find_const_4in_slice
{
@@ -112,9 +199,9 @@ struct find_const_4in_slice
{
auto ins    = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(false);
argument ends_arg   = inputs.at(2)->eval(false);
argument axes_arg   = inputs.at(3)->eval(false);
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
@@ -131,10 +218,116 @@ struct find_const_4in_slice
}
};
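/**
 * Illustrative example with made-up values, when all three of starts/ends/axes
 * are constant inputs:
 *   y = slice(x, @literal{0}, @literal{3}, @literal{0})
 * becomes
 *   y = slice[starts={0}, ends={3}, axes={0}](x)
 */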
/**
 * Simplify dimensions_of to a literal when the input argument has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {output_ndim}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
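/**
 * Illustrative example with a made-up shape: for an input of static shape
 * {2, 3, 4} and dimensions_of with start = 1, end = 3, the instruction is
 * replaced by the int64 literal {3, 4} of shape {2}.
 */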
/**
 * Simplify a reshape with 2 arguments, where the allocated output dimensions are
 * constant, into the static 1-argument reshape. Intended to simplify what ONNX
 * parse_reshape creates for dynamic reshapes. This matcher can be generalized to
 * matching reshape(data, static_shape_output_tensor).
* From:
* x = allocate(constant_output_dims) -> reshape(data, x)
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct find_const_alloc_reshapes
{
auto matcher() const
{
return match::name("reshape")(match::nargs(2),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto reshape_ins = mr.result;
auto reshape_inputs = reshape_ins->inputs();
auto alloc_ins = reshape_inputs.at(1);
argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false);
std::vector<int64_t> output_dims_vec;
output_dims_arg.visit(
[&](auto output) { output_dims_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0));
// have dead_code_elimination remove the previous allocate
}
};
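/**
 * Illustrative example with made-up values:
 *   dims = @literal{2, 3}
 *   x    = allocate(dims)
 *   y    = reshape(data, x)
 * becomes
 *   y = reshape[dims={2,3}](data)
 * and dead_code_elimination later removes the unused allocate.
 */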
/**
 * Simplify an allocate feeding a fill operator when the output dimensions and
 * fill value are constant. The allocate-into-fill pattern is what parsing the
 * ONNX ConstantOfShape operator produces. This replacement could be handled by
 * propagate_constant, but it is preferable to simplify it earlier in compilation.
 * This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
 * From:
 * x = allocate(constant_output_dims) -> fill(constant_value, x)
* To:
* literal
*/
struct find_const_alloc_fill
{
auto matcher() const
{
return match::name("fill")(match::arg(0)(match::is_constant()),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto fill_ins = mr.result;
auto fill_arg = fill_ins->eval(false);
auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data());
m.replace_instruction(fill_ins, l);
}
};
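/**
 * Illustrative example with made-up values, mirroring ONNX ConstantOfShape:
 *   val   = @literal{0.0f}
 *   dims  = @literal{2, 2}
 *   alloc = allocate(dims)
 *   y     = fill(val, alloc)
 * becomes a single literal of shape {2, 2} filled with 0.0f.
 */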
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
find_static_dimensions_of{},
find_const_alloc_reshapes{},
find_static_2in_broadcasts{},
find_const_2in_slice{},
find_const_3in_slice{},
find_const_4in_slice{},
find_const_alloc_fill{});
}
} // namespace MIGRAPHX_INLINE_NS
......