Unverified commit a6fa5e4b authored by Chris Austen, committed by GitHub

Merge branch 'develop' into enable_navi_32_ci

parents b7a7cd3c 7604ecf5
@@ -66,7 +66,7 @@ struct scatter : op_name<Derived>
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(3).standard();
+        check_shapes{inputs, *this}.has(3);
         // If non-packed, this converts to a packed output while preserving permutation of tensor
         return inputs.front().with_lens(inputs.front().lens());
     }
......
@@ -29,6 +29,17 @@
 #if defined(CPPCHECK)
 #define MIGRAPHX_HAS_OPTIONAL 1
 #define MIGRAPHX_HAS_OPTIONAL_TS 1
+#elif defined(_WIN32)
+#if _MSC_VER >= 1920
+#define MIGRAPHX_HAS_OPTIONAL 1
+#define MIGRAPHX_HAS_OPTIONAL_TS 0
+#elif _MSC_VER >= 1900
+#define MIGRAPHX_HAS_OPTIONAL 0
+#define MIGRAPHX_HAS_OPTIONAL_TS 1
+#else
+#define MIGRAPHX_HAS_OPTIONAL 0
+#define MIGRAPHX_HAS_OPTIONAL_TS 0
+#endif
 #elif defined(__has_include)
 #if __has_include(<optional>) && __cplusplus >= 201703L
 #define MIGRAPHX_HAS_OPTIONAL 1
......
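The new _WIN32 branch keys the feature macros off _MSC_VER: VS 2019 and newer (>= 1920) get <optional>, VS 2015/2017 (>= 1900) fall back to the experimental TS header, and anything older gets neither. A minimal standalone sketch (hypothetical consumer code, not the MIGraphX header) of how such a macro is typically consumed:

// A minimal sketch (hypothetical consumer code, not the MIGraphX header) of how a
// feature macro like MIGRAPHX_HAS_OPTIONAL is typically consumed: pick the type once
// behind an alias, then use it unconditionally.
#include <iostream>
#include <optional>

#define MIGRAPHX_HAS_OPTIONAL 1 // assumption for this sketch: <optional> is available

#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using migraphx_optional = std::optional<T>; // hypothetical alias name
#endif

int main()
{
    migraphx_optional<int> maybe_axis; // empty until an attribute is parsed
    maybe_axis = 1;
    std::cout << maybe_axis.value_or(-1) << "\n"; // prints 1
}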
@@ -64,6 +64,7 @@ shape compute_padded_shape(const shape& input,
 // Used for dynamic auto padding of pooling operators where padding needs to be computed at
 // evaluation time.
+MIGRAPHX_EXPORT
 shape compute_padded_pool_shape(const shape& input,
                                 const shape& kernel,
                                 const std::vector<std::size_t>& padding,
......
@@ -31,6 +31,7 @@
 #include <migraphx/module.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/ranges.hpp>
+#include <array>
 #include <string>
 namespace migraphx {
......
@@ -24,6 +24,7 @@
 #ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
 #define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
+#include <cstdint>
 #include <migraphx/config.hpp>
 #if defined(CPPCHECK)
......
@@ -34,6 +34,7 @@ struct MIGRAPHX_EXPORT tmp_dir
 {
     fs::path path;
     tmp_dir(const std::string& prefix = "");
+    tmp_dir(tmp_dir&&) = default;
     void execute(const std::string& exe, const std::string& args) const;
......
@@ -34,7 +34,7 @@ template <class PrivateMigraphTypeNameProbe>
 std::string compute_type_name()
 {
     std::string name;
-#ifdef _MSC_VER
+#if defined(_MSC_VER) && !defined(__clang__)
     name = typeid(PrivateMigraphTypeNameProbe).name();
     name = name.substr(7);
 #else
......
@@ -90,8 +90,7 @@ struct not_finite_fn
     template <class T>
     bool operator()(T x) const
     {
-        using std::isfinite;
-        return not isfinite(x);
+        return not std::isfinite(static_cast<double>(x));
     }
 };
 static constexpr not_finite_fn not_finite{};
@@ -101,8 +100,7 @@ struct compare_mag_fn
     template <class T, class U>
     bool operator()(T x, U y) const
    {
-        using std::fabs;
-        return fabs(x) < fabs(y);
+        return std::fabs(x) < std::fabs(y);
     }
 };
 static constexpr compare_mag_fn compare_mag{};
......
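The isfinite change drops the unqualified ADL call and casts to double first; a plausible motivation is that std::isfinite is only overloaded for the standard floating-point types, so routing everything through double keeps the predicate well-defined for integral and reduced-precision inputs. A small standalone sketch (plain C++, not the MIGraphX header) of the same pattern:

// A small standalone sketch (plain C++, not the MIGraphX header): casting to double
// before std::isfinite lets one predicate handle floats, doubles and integral values.
#include <cmath>
#include <iostream>
#include <limits>

template <class T>
bool not_finite(T x)
{
    return not std::isfinite(static_cast<double>(x));
}

int main()
{
    std::cout << not_finite(1.0f) << "\n";                                    // 0
    std::cout << not_finite(std::numeric_limits<double>::infinity()) << "\n"; // 1
    std::cout << not_finite(3) << "\n";                                       // 0
}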
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/broadcast_qdq.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Prepares a quantizelinear or dequantizelinear operation by broadcasting the
// weight scale or zero point of a qlinearadd operator to the shape of input x.
// output: op(x, broadcast scale (float), broadcast zero_pt (8-bit))
instruction_ref bcast_qdq_instr(const std::string& op_name,
instruction_ref x_in,
instruction_ref arg_fscale,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info)
{
auto in_lens = x_in->get_shape().lens();
// prep 1: broadcast scale. it can come as a scalar or a 1-D tensor.
instruction_ref bcast_scale;
if(arg_fscale->get_shape().elements() > 1)
bcast_scale = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", in_lens}}), arg_fscale);
else
bcast_scale = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_fscale);
// prep 2: broadcast zero point. it can come as a scalar or a 1-D tensor.
instruction_ref bcast_zero_pt;
if(arg_z_pt->get_shape().elements() > 1)
bcast_zero_pt = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", in_lens}}), arg_z_pt);
else
bcast_zero_pt = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_z_pt);
// op_name is either quantizelinear or dequantizelinear:
return info.add_instruction(migraphx::make_op(op_name), x_in, bcast_scale, bcast_zero_pt);
}
// Multibroadcast a scalar.
instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out,
instruction_ref arg_in,
const onnx_parser::node_info& info)
{
auto bcast_instr_out = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", shape_out.lens()}}), arg_in);
return bcast_instr_out;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_BROADCAST_QDQ_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_BROADCAST_QDQ_HPP
#include <string>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Prepares a quantizelinear or dequantizelinear operation by broadcasting the
// weight scale or zero point of a qlinearadd operator to the shape of input x.
// output: op(x, broadcast scale (float), broadcast zero_pt (8-bit))
instruction_ref bcast_qdq_instr(const std::string& op_name,
instruction_ref x_in,
instruction_ref arg_fscale,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info);
// Multibroadcast a scalar.
instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out,
instruction_ref arg_in,
const onnx_parser::node_info& info);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
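bcast_qdq_instr picks `broadcast` along axis 0 for a 1-D per-axis scale or zero point and `multibroadcast` for a scalar one. A standalone sketch (plain C++ loops, not the MIGraphX ops) of the indexing rule those two choices imply:

// A standalone sketch (plain C++ loops, not the MIGraphX ops) of the indexing rule:
// a 1-D scale is replicated along every axis except the broadcast axis, while a
// scalar zero point is replicated to every element.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> scale      = {0.5f, 2.0f}; // 1-D per-axis scale
    std::vector<std::size_t> lens = {2, 3};       // target shape; broadcast axis is 0
    for(std::size_t i = 0; i < lens[0]; ++i)
        for(std::size_t j = 0; j < lens[1]; ++j)
            std::cout << "scale(" << i << "," << j << ") = " << scale[i] << "\n";

    float scalar_zero_pt = 0.0f; // a scalar zero point applies everywhere
    std::cout << "zero_pt everywhere = " << scalar_zero_pt << "\n";
}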
@@ -244,7 +244,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
     this->filename   = std::move(name);
     auto parent_path = fs::path(this->filename).parent_path();
     if(not parent_path.empty())
-        this->path = parent_path;
+        this->path = parent_path.string();
     onnx::ModelProto model;
     if(model.ParseFromIstream(&is))
......
@@ -47,7 +47,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
         return;
     }
-    auto auto_pad = info.attributes["auto_pad"].s();
+    auto auto_pad = to_upper(info.attributes["auto_pad"].s());
     if(auto_pad.find("SAME") != std::string::npos)
     {
         bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
......
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
             keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
         }
+        bool select_last_index = false;
+        if(contains(info.attributes, "select_last_index"))
+        {
+            select_last_index = static_cast<bool>(
+                parser.parse_value(info.attributes.at("select_last_index")).at<int>());
+        }
         if(keep_dims == 0)
         {
-            auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
+            auto ins = info.add_instruction(
+                make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
+                args);
             return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
         }
         else
         {
-            return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
+            return info.add_instruction(
+                make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
+                args);
         }
     }
 };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_castlike : op_parser<parse_castlike>
{
std::vector<op_desc> operators() const { return {{"CastLike"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
if(not(args.size() == 2))
{
MIGRAPHX_THROW("PARSE_CASTLIKE: CastLike must have exactly 2 inputs!");
}
shape::type_t target_type = args[1]->get_shape().type();
return info.add_instruction(make_op("convert", {{"target_type", target_type}}), args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -49,6 +49,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
             {
                 MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
             }
+            // convert to a scalar literal
+            l_val = literal(shape{l_val.get_shape().type(), {1}, {0}}, l_val.data());
         }
         else
         {
@@ -64,30 +66,37 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
             migraphx::shape s;
             // input is empty, output is a scalar
             auto type = l_val.get_shape().type();
-            // empty input tensor, output is a scalar
-            if(args[0]->get_shape().elements() == 0)
+            migraphx::argument input = args[0]->eval();
+            if(not input.empty())
             {
-                s = migraphx::shape{type, {1}, {0}};
+                // empty input tensor, output is a scalar
+                if(args[0]->get_shape().elements() == 0)
+                {
+                    s = migraphx::shape{type, {1}, {0}};
+                }
+                else
+                {
+                    std::vector<std::size_t> dims;
+                    input.visit([&](auto ia) { dims.assign(ia.begin(), ia.end()); });
+                    s = migraphx::shape{type, dims};
+                }
+                literal l_out{};
+                l_val.visit([&](auto val) {
+                    using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
+                    // l_val contains only one element
+                    std::vector<val_type> out_vec(s.elements(), val.front());
+                    l_out = literal(s, out_vec);
+                });
+                return info.add_literal(l_out);
             }
+            // has variable input (dynamic shape buffer)
             else
             {
-                migraphx::argument in = args[0]->eval();
-                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
-                std::vector<std::size_t> dims;
-                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
-                s = migraphx::shape{type, dims};
+                auto dv_lit = info.add_literal(l_val);
+                auto alloc_ins =
+                    info.add_instruction(make_op("allocate", {{"buf_type", type}}), args[0]);
+                return info.add_instruction(make_op("fill"), dv_lit, alloc_ins);
             }
-            literal l_out{};
-            l_val.visit([&](auto val) {
-                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
-                // l_val contains only one element
-                std::vector<val_type> out_vec(s.elements(), val.front());
-                l_out = literal(s, out_vec);
-            });
-            return info.add_literal(l_out);
         }
     }
 };
......
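With a non-constant shape input, the parser now emits `allocate` (whose extent comes from the shape argument at run time) followed by `fill` with the single-element value literal, instead of rejecting the model. A standalone sketch (plain C++, not the MIGraphX ops) of what that pair computes:

// A standalone sketch (plain C++, not the MIGraphX ops) of the allocate + fill pair:
// a buffer whose extent is only known at run time, every element set to the single
// value carried by the ConstantOfShape literal.
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::vector<std::size_t> dims = {2, 3}; // run-time shape input
    float value                   = 1.5f;   // the one-element "value" attribute
    std::size_t elements =
        std::accumulate(dims.begin(), dims.end(), std::size_t{1}, std::multiplies<>{});
    std::vector<float> out(elements, value); // allocate, then fill
    std::cout << out.size() << " elements, all " << out.front() << "\n";
}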
@@ -87,8 +87,7 @@ struct parse_depthtospace : op_parser<parse_depthtospace>
         auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]);
         auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
-        return info.add_instruction(make_op("reshape", {{"dims", lens2}}),
-                                    info.make_contiguous(temp2));
+        return info.add_instruction(make_op("reshape", {{"dims", lens2}}), temp2);
     }
 };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_groupnorm : op_parser<parse_groupnorm>
{
std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
size_t num_groups;
if(contains(info.attributes, "num_groups"))
{
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
}
else
{
MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
}
if(args.size() != 3)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
}
auto x = args.at(0);
auto scale = args.at(1);
auto bias = args.at(2);
auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
auto x_dims = x_shape.lens();
if(x_shape.ndim() <= 2)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input shape");
}
auto c = x_shape.lens().at(1);
if(c % num_groups != 0)
{
MIGRAPHX_THROW(
"PARSE_GROUPNORM: num_groups should be a divisor of the number of channels");
}
auto group_size = c / num_groups;
if(scale->get_shape().ndim() != 1 or scale->get_shape().lens().at(0) != num_groups)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: scale tensor shape should be num_groups");
}
if(bias->get_shape().ndim() != 1 or bias->get_shape().lens().at(0) != num_groups)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: bias tensor shape should be num_groups");
}
// Original shape: N x C x D1 x ... x Dn
// New shape: N x num_groups x C // num_groups x D1 x ... x Dn
std::vector<size_t> dims = {x_dims.at(0), num_groups, group_size};
std::copy(x_dims.begin() + 2, x_dims.end(), std::back_inserter(dims));
auto x_reshaped = info.add_instruction(make_op("reshape", {{"dims", dims}}), x);
// Axes for D1 x ... x Dn
std::vector<size_t> axes(dims.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
// y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_reshaped);
auto x_sub_mean = info.add_common_op("sub", x_reshaped, mean);
auto x_sqdiff_mean = info.add_common_op("sqdiff", x_reshaped, mean);
auto variance =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
auto var_eps = info.add_common_op("add", variance, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto result = info.add_common_op("mul", x_sub_mean, rsqrt);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
auto y_reshaped = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
return y_reshaped;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
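The parser implements GroupNormalization by reshaping N x C x ... to N x num_groups x (C / num_groups) x ..., reducing over everything past the group axis, and then applying the per-group scale and bias. A standalone sketch (plain C++, hypothetical toy sizes, not the parser) of the same arithmetic on a 1 x 4 input with two groups:

// A standalone sketch (plain C++, hypothetical toy sizes, not the parser) of the
// arithmetic GroupNormalization lowers to: view N x C as N x G x (C/G), normalize
// within each group, then apply the per-group scale and bias.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const std::size_t n = 1, c = 4, g = 2, gs = c / g; // group_size = C / num_groups
    std::vector<float> x = {1, 2, 3, 4};                // shape N x C, flattened
    std::vector<float> scale = {1, 1}, bias = {0, 0};   // one entry per group
    const float eps = 1e-5f;
    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < n * g; ++i) // loop over (sample, group) pairs
    {
        float mean = 0, var = 0;
        for(std::size_t j = 0; j < gs; ++j)
            mean += x[i * gs + j] / gs;
        for(std::size_t j = 0; j < gs; ++j)
            var += (x[i * gs + j] - mean) * (x[i * gs + j] - mean) / gs;
        for(std::size_t j = 0; j < gs; ++j)
            y[i * gs + j] =
                (x[i * gs + j] - mean) / std::sqrt(var + eps) * scale[i % g] + bias[i % g];
    }
    for(float v : y)
        std::cout << v << " "; // approximately -1 1 -1 1
    std::cout << "\n";
}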
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_layernorm : op_parser<parse_layernorm>
{
std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int64_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
}
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "stash_type"))
{
std::cerr << "WARNING: LAYERNORM does not support stash_type, it will be ignored.\n";
}
if(args.size() < 2 or args.size() > 3)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input count");
}
auto x = args.at(0);
auto scale = args.at(1);
bool skip_bias = args.size() == 2;
instruction_ref bias;
if(not skip_bias)
{
bias = args.at(2);
}
auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
int64_t x_rank = x_shape.ndim();
if(x_rank < 2)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input shape");
}
// If rank(X) is r, axis' allowed range is [-r, r)
if(axis < -x_rank or axis >= x_rank)
{
MIGRAPHX_THROW("PARSE_LAYERNORM: invalid axis");
}
// y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// axis can be negative
axis = axis < 0 ? axis + x_rank : axis;
auto kdims = x_rank - axis;
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), axis);
auto skipped_axes = x_rank - kdims;
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto x_sub_mean = info.add_common_op("sub", x, mean);
auto x_sqdiff_mean = info.add_common_op("sqdiff", x, mean);
auto variance =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
auto var_eps = info.add_common_op("add", variance, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto result = info.add_common_op("mul", x_sub_mean, rsqrt);
instruction_ref scale_bcast = scale;
instruction_ref bias_bcast = bias;
if(skipped_axes > 0)
{
auto x_dims = x_shape.lens();
scale_bcast = info.add_instruction(
make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), scale);
if(not skip_bias)
{
bias_bcast = info.add_instruction(
make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), bias);
}
}
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = skip_bias ? scaled : info.add_instruction(make_op("add"), scaled, bias_bcast);
return {y, mean, rsqrt};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mean_variance_normalization : op_parser<parse_mean_variance_normalization>
{
std::vector<op_desc> operators() const { return {{"MeanVarianceNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto&& data = args.front();
auto data_rank = data->get_shape().ndim();
std::vector<int64_t> axes{0, 2, 3};
if(contains(info.attributes, "axes"))
{
const auto& axes_attr = info.attributes["axes"].ints();
axes.assign(axes_attr.begin(), axes_attr.end());
}
else if(data_rank != 4)
{
MIGRAPHX_THROW(
"Input tensor needs to be rank 4 when axes is not specified. Instead it is rank " +
std::to_string(data_rank));
}
if(axes.size() != data_rank - 1)
{
MIGRAPHX_THROW("Length of axes array needs to be equal to input tensor rank - 1");
}
auto data_mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), data);
auto data_mean_squared = info.add_common_op("mul", data_mean, data_mean);
auto data_squared = info.add_common_op("mul", data, data);
auto data_squared_mean =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), data_squared);
auto mean_sub = info.add_common_op("sub", data_squared_mean, data_mean_squared);
auto std = info.add_common_op("sqrt", mean_sub);
auto dividend = info.add_common_op("sub", data, data_mean);
auto epsilon =
info.add_literal({data->get_shape().type(),
{data->get_shape().type() == shape::half_type ? 1e-7 : 1e-9}});
auto divisor = info.add_common_op("add", std, epsilon);
return info.add_common_op("div", dividend, divisor);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
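The MeanVarianceNormalization lowering relies on the identity variance = mean(x*x) - mean(x)^2, so it only needs two reduce_mean instructions plus elementwise ops. A standalone sketch (plain C++) of that identity on a small vector:

// A standalone sketch (plain C++, not the parser) of the identity the MVN lowering
// uses: variance = mean(x*x) - mean(x)^2, computed from two running means.
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> x = {1, 2, 3, 4};
    float mean = 0, mean_sq = 0;
    for(float v : x)
    {
        mean += v / x.size();
        mean_sq += v * v / x.size();
    }
    float variance = mean_sq - mean * mean; // E[x^2] - E[x]^2
    std::cout << "mean=" << mean << " variance=" << variance << "\n"; // mean=2.5 variance=1.25
}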
@@ -115,34 +115,9 @@ struct parse_pad : op_parser<parse_pad>
 {
     std::vector<op_desc> operators() const { return {{"Pad"}}; }
-    instruction_ref parse(const op_desc& /*opd*/,
-                          const onnx_parser& parser,
-                          onnx_parser::node_info info,
-                          std::vector<instruction_ref> args) const
+    std::string parse_mode(const onnx_parser::node_info& info,
+                           const std::vector<instruction_ref>& args) const
     {
-        std::vector<int64_t> pads{};
-        if(args.size() >= 2)
-        {
-            auto pad_arg = args.at(1)->eval();
-            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
-            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
-        }
-        else if(contains(info.attributes, "pads"))
-        {
-            auto&& pad_vals = info.attributes["pads"].ints();
-            pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
-        }
-        else
-        {
-            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
-        }
-        // check if padding is actually being done (at least one value is nonzero)
-        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
-        {
-            return info.add_instruction(make_op("identity"), args.front());
-        }
         if(contains(info.attributes, "mode"))
         {
             auto mode = info.attributes.at("mode").s();
@@ -152,28 +127,59 @@ struct parse_pad : op_parser<parse_pad>
                 {
                     MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
                 }
-                return reflect_pad(info, pads, args.front());
             }
-            if(mode != "constant")
+            else if(mode != "constant")
             {
                 MIGRAPHX_THROW(
                     "PARSE_PAD: migraphx currently only supports constant and reflect padding");
             }
+            return mode;
+        }
+        else
+        {
+            // default mode
+            return "constant";
+        }
+    }
+    std::vector<int64_t> parse_pads(const onnx_parser::node_info& info,
+                                    const std::vector<instruction_ref>& args) const
+    {
+        std::vector<int64_t> pads{};
+        if(args.size() >= 2)
+        {
+            auto pad_arg = args.at(1)->eval();
+            check_arg_empty(pad_arg, "PARSE_PAD: `pads` input must be constant");
+            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
+        }
+        else if(contains(info.attributes, "pads"))
+        {
+            auto&& pad_vals = info.attributes.at("pads").ints();
+            pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
+        }
+        else
+        {
+            MIGRAPHX_THROW("PARSE_PAD: `pads` must be available");
+        }
+        return pads;
+    }
+    float parse_constant_value(const onnx_parser& parser,
+                               const onnx_parser::node_info& info,
+                               const std::vector<instruction_ref>& args) const
+    {
         float value = 0.0f;
-        // third input is the value
-        if(args.size() == 3)
+        if(args.size() >= 3 and args.at(2)->get_shape().scalar())
         {
             auto val_ins = args.at(2);
             if(not val_ins->can_eval())
             {
-                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
+                MIGRAPHX_THROW("PARSE_PAD: input `value` must be constant");
             }
             auto val_arg = val_ins->eval();
             if(val_arg.get_shape().elements() != 1)
             {
-                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
+                MIGRAPHX_THROW("PARSE_PAD: `value` should contain only one element");
             }
             value = val_arg.at<float>();
         }
@@ -181,6 +187,81 @@
         {
             value = parser.parse_value(info.attributes.at("value")).at<float>();
         }
+        return value;
+    }
+    std::vector<int64_t> parse_axes(const std::vector<instruction_ref>& args,
+                                    bool is_constant_mode) const
+    {
+        std::vector<int64_t> axes{};
+        // axes is 3rd or 4th, depending on constant mode
+        auto pos = is_constant_mode ? 4 : 3;
+        if(args.size() >= pos)
+        {
+            auto axes_arg = args.at(pos - 1)->eval();
+            check_arg_empty(axes_arg, "PARSE_PAD: variable `axes` input not supported");
+            axes_arg.visit([&](auto v) { axes.assign(v.begin(), v.end()); });
+        }
+        return axes;
+    }
+    std::vector<int64_t> calculate_pads_with_axes(const std::vector<int64_t>& pads,
+                                                  const std::vector<int64_t>& axes,
+                                                  size_t input_rank) const
+    {
+        size_t num_axes = axes.size();
+        if(num_axes * 2 != pads.size())
+        {
+            MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * "
+                           "number of elements of axes");
+        }
+        std::vector<int64_t> new_pads(input_rank * 2);
+        for(size_t idx{0}; idx < num_axes; ++idx)
+        {
+            // axis can be negative
+            int64_t axis = axes[idx] < 0 ? input_rank + axes[idx] : axes[idx];
+            // pad format is x1_begin, x2_begin, ... , x1_end, x2_end
+            new_pads[axis] = pads[idx];
+            new_pads[axis + input_rank] = pads[idx + num_axes];
+        }
+        return new_pads;
+    }
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& parser,
+                          const onnx_parser::node_info& info,
+                          const std::vector<instruction_ref>& args) const
+    {
+        std::vector<int64_t> pads = parse_pads(info, args);
+        // check if padding is actually being done (at least one value is nonzero)
+        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
+        {
+            return info.add_instruction(make_op("identity"), args.front());
+        }
+        std::string mode = parse_mode(info, args);
+        bool is_constant_mode = mode == "constant";
+        float value = is_constant_mode ? parse_constant_value(parser, info, args) : 0.0f;
+        std::vector<int64_t> axes = parse_axes(args, is_constant_mode);
+        size_t input_rank = args.front()->get_shape().ndim();
+        if(not axes.empty())
+        {
+            pads = calculate_pads_with_axes(pads, axes, input_rank);
+        }
+        if(pads.size() != input_rank * 2)
+        {
+            MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * "
+                           "input rank");
+        }
+        if(mode == "reflect")
+        {
+            return reflect_pad(info, pads, args.front());
+        }
         return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
                                     args.front());
......
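calculate_pads_with_axes expands pads given only for selected (possibly negative) axes into the full ONNX layout x1_begin, x2_begin, ..., x1_end, x2_end expected by the pad operator. A standalone sketch (plain C++, hypothetical rank-4 input) of that remapping:

// A standalone sketch (plain C++, hypothetical rank-4 input, not the parser) of
// expanding per-axes pads into the full begin/end layout used by the pad op.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const std::size_t rank = 4;               // input is N x C x H x W
    std::vector<int64_t> axes = {2, -1};      // pad H and W only; negative axes wrap
    std::vector<int64_t> pads = {1, 2, 1, 2}; // h_begin, w_begin, h_end, w_end
    std::vector<int64_t> new_pads(rank * 2, 0);
    const std::size_t num_axes = axes.size();
    for(std::size_t i = 0; i < num_axes; ++i)
    {
        int64_t axis = axes[i] < 0 ? axes[i] + static_cast<int64_t>(rank) : axes[i];
        new_pads[axis]        = pads[i];            // begin side
        new_pads[axis + rank] = pads[i + num_axes]; // end side
    }
    for(int64_t p : new_pads)
        std::cout << p << " "; // 0 0 1 2 0 0 1 2
    std::cout << "\n";
}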