Commit 4ea39116 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 20128cae d8011adf
......@@ -64,6 +64,7 @@ shape compute_padded_shape(const shape& input,
// Used for dynamic auto padding of pooling operators where padding needs to be computed at
// evaluation time.
MIGRAPHX_EXPORT
shape compute_padded_pool_shape(const shape& input,
const shape& kernel,
const std::vector<std::size_t>& padding,
......
......@@ -31,6 +31,7 @@
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <array>
#include <string>
namespace migraphx {
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#include <cstdint>
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
......
......@@ -30,6 +30,7 @@
#include <migraphx/rank.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <vector>
namespace migraphx {
......@@ -68,6 +69,19 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(
os << x;
}
template <class T>
auto stream_write_value_impl(rank<1>, std::ostream& os, const optional<T>& x)
{
    // Print the contained value when present, or the sentinel "nullopt".
    if(x)
        os << *x;
    else
        os << "nullopt";
}
template <class T>
void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
{
......
......@@ -34,6 +34,7 @@ struct MIGRAPHX_EXPORT tmp_dir
{
fs::path path;
tmp_dir(const std::string& prefix = "");
tmp_dir(tmp_dir&&) = default;
void execute(const std::string& exe, const std::string& args) const;
......
......@@ -34,7 +34,7 @@ template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
#if defined(_MSC_VER) && !defined(__clang__)
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
......
......@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(input_shape.dynamic())
{
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
return not input_shape.dyn_dims().at(ax).is_fixed();
}))
{
return vec;
}
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i);
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
return input_shape.dyn_dims().at(i).max;
});
}
else
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,58 +21,56 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace onnx {
shape pack_int8_shape(const shape& s)
// This method is to prep for quantizelinear or dequantizelinear operation for
// either the broadcasting of weight-scale or zero-points of qlinearadd operator
// outputs: operator op (inputs x, broadcasted: scale (float) & zero_pt (8-bit))
instruction_ref bcast_qdq_instr(const std::string& op_name,
instruction_ref x_in,
instruction_ref arg_fscale,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info)
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto in_lens = x_in->get_shape().lens();
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
// prep 1: broadcast scale. it can come as a scalar or a 1-D tensor.
instruction_ref bcast_scale;
if(arg_fscale->get_shape().elements() > 1)
bcast_scale = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", in_lens}}), arg_fscale);
else
bcast_scale = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_fscale);
return {s.type(), lens, strides};
}
// prep 2: broadcast zero point. it can come as a scalar or a 1-D tensor.
instruction_ref bcast_zero_pt;
if(arg_z_pt->get_shape().elements() > 1)
bcast_zero_pt = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", in_lens}}), arg_z_pt);
else
bcast_zero_pt = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_z_pt);
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return pack_int8_shape(inputs.at(0));
// op_name is either quantizelinear or dequantizelinear:
return info.add_instruction(migraphx::make_op(op_name), x_in, bcast_scale, bcast_zero_pt);
}
argument
miopen_int8_conv_pack::compute(context& ctx, const shape&, const std::vector<argument>& args) const
// Multibroadcast a scalar.
instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out,
instruction_ref arg_in,
const onnx_parser::node_info& info)
{
auto arg_desc = make_tensor(args[0].get_shape());
auto arg_desc_vec4 = make_tensor(args[0].get_shape(), true);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
arg_desc.get(),
args[0].implicit(),
&beta,
arg_desc_vec4.get(),
args[1].implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
}
return args[1];
auto bcast_instr_out = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", shape_out.lens()}}), arg_in);
return bcast_instr_out;
}
} // namespace gpu
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,42 +21,35 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <utility>
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_BROADCAST_QDQ_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_BROADCAST_QDQ_HPP
#include <string>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_int8_gemm_pack_a
{
std::string name() const { return "gpu::int8_gemm_pack_a"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_int8_gemm_pack_b
{
std::string name() const { return "gpu::int8_gemm_pack_b"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
namespace onnx {
// This method is to prep for quantizelinear or dequantizelinear operation for
// either the broadcasting of weight-scale or zero-points of qlinearadd operator
// outputs: operator op (inputs x, broadcasted: scale (float) & zero_pt (8-bit))
instruction_ref bcast_qdq_instr(const std::string& op_name,
instruction_ref x_in,
instruction_ref arg_fscale,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info);
// Multibroadcast a scalar.
instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out,
instruction_ref arg_in,
const onnx_parser::node_info& info);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -97,10 +97,11 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops;
......
......@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
}
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
parser.limit_max_iterations = options.limit_max_iterations;
parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error)
......
......@@ -244,7 +244,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
this->filename = std::move(name);
auto parent_path = fs::path(this->filename).parent_path();
if(not parent_path.empty())
this->path = parent_path;
this->path = parent_path.string();
onnx::ModelProto model;
if(model.ParseFromIstream(&is))
......
......@@ -47,7 +47,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
return;
}
auto auto_pad = info.attributes["auto_pad"].s();
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
if(auto_pad.find("SAME") != std::string::npos)
{
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
}
bool select_last_index = false;
if(contains(info.attributes, "select_last_index"))
{
select_last_index = static_cast<bool>(
parser.parse_value(info.attributes.at("select_last_index")).at<int>());
}
if(keep_dims == 0)
{
auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
auto ins = info.add_instruction(
make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
args);
return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
else
{
return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
return info.add_instruction(
make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
args);
}
}
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......
......@@ -87,8 +87,7 @@ struct parse_depthtospace : op_parser<parse_depthtospace>
auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]);
auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
return info.add_instruction(make_op("reshape", {{"dims", lens2}}),
info.make_contiguous(temp2));
return info.add_instruction(make_op("reshape", {{"dims", lens2}}), temp2);
}
};
......
......@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Round", "nearbyint"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_groupnorm : op_parser<parse_groupnorm>
{
    std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }

    // Lowers ONNX GroupNormalization into primitive ops:
    //   y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
    // where mean/variance are computed per (batch, group) over the
    // group's channels and all spatial dims.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();

        // num_groups has no default in the ONNX spec; it must be present.
        if(not contains(info.attributes, "num_groups"))
            MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
        const auto num_groups =
            parser.parse_value(info.attributes.at("num_groups")).at<size_t>();

        // Inputs: X, scale, bias — all three are required.
        if(args.size() != 3)
            MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
        auto x     = args.at(0);
        auto scale = args.at(1);
        auto bias  = args.at(2);

        const auto x_shape = x->get_shape();
        const auto x_dtype = x_shape.type();
        const auto x_lens  = x_shape.lens();
        if(x_shape.ndim() <= 2)
            MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input shape");

        const auto channels = x_lens.at(1);
        if(channels % num_groups != 0)
            MIGRAPHX_THROW(
                "PARSE_GROUPNORM: num_groups should be a divisor of the number of channels");
        const auto group_size = channels / num_groups;

        // Scale and bias are 1-D tensors with one entry per group.
        if(scale->get_shape().ndim() != 1 or scale->get_shape().lens().at(0) != num_groups)
            MIGRAPHX_THROW("PARSE_GROUPNORM: scale tensor shape should be num_groups");
        if(bias->get_shape().ndim() != 1 or bias->get_shape().lens().at(0) != num_groups)
            MIGRAPHX_THROW("PARSE_GROUPNORM: bias tensor shape should be num_groups");

        // N x C x D1 x ... x Dn  ->  N x num_groups x (C / num_groups) x D1 x ... x Dn
        std::vector<size_t> grouped_dims{x_lens.at(0), num_groups, group_size};
        grouped_dims.insert(grouped_dims.end(), x_lens.begin() + 2, x_lens.end());
        auto x_grouped = info.add_instruction(make_op("reshape", {{"dims", grouped_dims}}), x);

        // Reduce over every axis past {N, num_groups}.
        std::vector<size_t> reduce_axes(grouped_dims.size() - 2);
        std::iota(reduce_axes.begin(), reduce_axes.end(), 2);

        auto mean =
            info.add_instruction(make_op("reduce_mean", {{"axes", reduce_axes}}), x_grouped);
        auto centered = info.add_common_op("sub", x_grouped, mean);
        auto sq_diff  = info.add_common_op("sqdiff", x_grouped, mean);
        auto variance =
            info.add_instruction(make_op("reduce_mean", {{"axes", reduce_axes}}), sq_diff);

        // fp16 cannot represent epsilons below ~1e-7; clamp so var + eps stays > 0.
        if(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7)
            epsilon = 1e-7;
        auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
        auto var_eps = info.add_common_op("add", variance, eps);
        auto inv_std = info.add_instruction(make_op("rsqrt"), var_eps);
        auto normed  = info.add_common_op("mul", centered, inv_std);

        // Per-group scale/bias broadcast along the group axis (axis 1).
        auto scale_bcast = info.add_instruction(
            make_op("broadcast", {{"axis", 1}, {"out_lens", grouped_dims}}), scale);
        auto bias_bcast = info.add_instruction(
            make_op("broadcast", {{"axis", 1}, {"out_lens", grouped_dims}}), bias);
        auto scaled  = info.add_instruction(make_op("mul"), normed, scale_bcast);
        auto shifted = info.add_instruction(make_op("add"), scaled, bias_bcast);
        // Restore the original N x C x D1..Dn layout.
        return info.add_instruction(make_op("reshape", {{"dims", x_lens}}), shifted);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_isinf : op_parser<parse_isinf>
{
    std::vector<op_desc> operators() const { return {{"IsInf", "isinf"}}; }

    // ONNX IsInf: marks +/-inf elements. The detect_negative and
    // detect_positive attributes (both default 1) restrict detection
    // to a single sign.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        // Read an int attribute as a boolean flag, defaulting to true.
        auto read_flag = [&](const std::string& name) {
            if(not contains(info.attributes, name))
                return true;
            return parser.parse_value(info.attributes.at(name)).at<int>() != 0;
        };
        const bool detect_negative = read_flag("detect_negative");
        const bool detect_positive = read_flag("detect_positive");

        const auto x_shape = args[0]->get_shape();
        if(not(detect_negative or detect_positive))
        {
            // Both checks disabled: the result is a constant all-false tensor.
            return info.add_instruction(
                make_op("multibroadcast", {{"out_lens", x_shape.lens()}}),
                info.add_literal(migraphx::literal{migraphx::shape{shape::bool_type}, {false}}));
        }

        auto is_inf = info.add_instruction(make_op("isinf"), args[0]);
        if(detect_negative and detect_positive)
            return is_inf;

        // Single-sided detection: mask the isinf result by the input's sign.
        auto zero_lit =
            info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
        auto zero_mb = info.add_instruction(
            make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), zero_lit);
        auto sign_mask = info.add_broadcastable_binary_op(
            detect_negative ? "less" : "greater", args[0], zero_mb);
        if(sign_mask->get_shape().type() != shape::bool_type)
        {
            sign_mask = info.add_instruction(
                make_op("convert", {{"target_type", shape::bool_type}}), sign_mask);
        }
        return info.add_instruction(make_op("logical_and"), is_inf, sign_mask);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_layernorm : op_parser<parse_layernorm>
{
std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }
// Lowers ONNX LayerNormalization into primitive ops:
//   y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
// Returns {y, mean, rsqrt(variance + epsilon)} — the three outputs of the
// ONNX operator (Y, Mean, InvStdDev in ONNX terms).
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
// First axis to normalize over; negative values count from the end.
int64_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
}
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
// stash_type (requested precision of the saved mean/inv-std outputs) is
// not honored here; warn and proceed.
if(contains(info.attributes, "stash_type"))
{
std::cerr << "WARNING: LAYERNORM does not support stash_type, it will be ignored.\n";
}
// Inputs: X and Scale are required; B (bias) is optional.
if(args.size() < 2 or args.size() > 3)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input count");
}
auto x = args.at(0);
auto scale = args.at(1);
bool skip_bias = args.size() == 2;
instruction_ref bias;
if(not skip_bias)
{
bias = args.at(2);
}
auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
int64_t x_rank = x_shape.ndim();
if(x_rank < 2)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input shape");
}
// If rank(X) is r, axis' allowed range is [-r, r)
if(axis < -x_rank or axis >= x_rank)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid axis");
}
// y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// axis can be negative
axis = axis < 0 ? axis + x_rank : axis;
// Normalization covers the kdims trailing axes [axis, x_rank).
auto kdims = x_rank - axis;
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), axis);
// skipped_axes == axis: number of leading (batch-like) dims left alone.
auto skipped_axes = x_rank - kdims;
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto x_sub_mean = info.add_common_op("sub", x, mean);
auto x_sqdiff_mean = info.add_common_op("sqdiff", x, mean);
auto variance =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
// fp16 cannot represent epsilons below ~1e-7; bump so variance + eps > 0.
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
auto var_eps = info.add_common_op("add", variance, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto result = info.add_common_op("mul", x_sub_mean, rsqrt);
// NOTE(review): when skip_bias is true, `bias` is default-constructed and
// copied into bias_bcast; it is never dereferenced in that case, but
// confirm copying a default-constructed instruction_ref is well-defined.
instruction_ref scale_bcast = scale;
instruction_ref bias_bcast = bias;
if(skipped_axes > 0)
{
// Scale/bias only span the normalized trailing dims; broadcast them
// across the leading axes so element-wise mul/add line up with x.
auto x_dims = x_shape.lens();
scale_bcast = info.add_instruction(
make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), scale);
if(not skip_bias)
{
bias_bcast = info.add_instruction(
make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), bias);
}
}
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = skip_bias ? scaled : info.add_instruction(make_op("add"), scaled, bias_bcast);
return {y, mean, rsqrt};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment