"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "95a28019ba6c7288c1d2e747665d6a9dd005fdc2"
Unverified Commit a24ed87e authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into optimize_jenkinsfile

parents 6481cd69 a09dc502
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <vector> #include <vector>
namespace migraphx { namespace migraphx {
...@@ -68,6 +69,19 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype( ...@@ -68,6 +69,19 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(
os << x; os << x;
} }
template <class T>
auto stream_write_value_impl(rank<1>, std::ostream& os, const optional<T>& x)
{
if(x.has_value())
{
os << *x;
}
else
{
os << "nullopt";
}
}
template <class T> template <class T>
void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r) void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -24,21 +24,21 @@ ...@@ -24,21 +24,21 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP #define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR") inline int tune_axis(int n_dim, int axis, const std::string& op_name = "OPERATOR")
{ {
if(axis >= n_dim or std::abs(axis) > n_dim) if(axis < 0)
{ axis += n_dim;
if(axis < 0 or axis >= n_dim)
MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range."); MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
}
return (axis < 0) ? axis + n_dim : axis; return axis;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -28,25 +28,35 @@ ...@@ -28,25 +28,35 @@
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \ template <> \
struct trait<T> : std::true_type \ struct trait<T> : std::true_type \
{ \ { \
}; };
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T> template <class T>
using accumulator_type = using accumulator_type =
std::conditional_t<is_floating_point<T>{}, std::conditional_t<is_floating_point<T>{},
......
...@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
return not input_shape.dyn_dims().at(ax).is_fixed();
}))
{
return vec;
}
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i); return input_shape.dyn_dims().at(i).max;
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
}); });
} }
else else
......
...@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED) ...@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto) protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto)
add_library(onnx-proto STATIC ${PROTO_SRCS}) add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR}) target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w) if(MSVC)
target_compile_options(onnx-proto PRIVATE /w)
else()
target_compile_options(onnx-proto PRIVATE -w)
endif()
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
...@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) ...@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header(migraphx_onnx) migraphx_generate_export_header(migraphx_onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx) rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL") target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
if(NOT WIN32)
target_link_libraries(migraphx_onnx PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_onnx PUBLIC migraphx) target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets( rocm_install_targets(
......
...@@ -97,10 +97,11 @@ struct onnx_parser ...@@ -97,10 +97,11 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,27 +21,26 @@ ...@@ -21,27 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP #include <migraphx/config.hpp>
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP #include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/argument.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace onnx {
namespace device {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream, instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
} // namespace device } // namespace onnx
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
} }
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations; parser.max_loop_iterations = options.max_loop_iterations;
parser.limit_max_iterations = options.limit_max_iterations;
parser.use_dyn_output = options.use_dyn_output; parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error) if(options.print_program_on_error)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
// //
// Copyright (c) ONNX Project Contributors. // SPDX-License-Identifier: Apache-2.0
// Licensed under the MIT license.
syntax = "proto2"; syntax = "proto2";
...@@ -20,23 +20,16 @@ package onnx_for_migraphx; ...@@ -20,23 +20,16 @@ package onnx_for_migraphx;
// //
// This document describes the syntax of models and their computation graphs, // This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX // as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short. // Intermediate Representation, or 'IR' for short.
// //
// The normative semantic specification of the ONNX IR is found in docs/IR.md. // The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md. // Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes // Notes
// //
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility // Protobuf compatibility
// //
// To simplify framework compatibility, ONNX is defined using the subset of protobuf // To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any // that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions. // protobuf features that are only available in one of the two versions.
// //
...@@ -60,7 +53,7 @@ enum Version { ...@@ -60,7 +53,7 @@ enum Version {
_START_VERSION = 0; _START_VERSION = 0;
// The version field is always serialized and we will use it to store the // The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version // version that the graph is generated from. This helps us set up version
// control. // control.
// For the IR, we are using simple numbers starting with 0x00000001, // For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017. // which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001; IR_VERSION_2017_10_10 = 0x0000000000000001;
...@@ -92,15 +85,28 @@ enum Version { ...@@ -92,15 +85,28 @@ enum Version {
// - Add sparse initializers // - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006; IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on <TBD> // IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add a list to promote inference graph's initializers to global and // - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the // mutable variables. Global variables are visible in all graphs of the
// stored models. // stored models.
// - Add message TrainingInfoProto to store initialization // - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto // method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables. // can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator. // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION = 0x0000000000000007; IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on TBD
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
} }
// Attributes // Attributes
...@@ -121,6 +127,7 @@ message AttributeProto { ...@@ -121,6 +127,7 @@ message AttributeProto {
TENSOR = 4; TENSOR = 4;
GRAPH = 5; GRAPH = 5;
SPARSE_TENSOR = 11; SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6; FLOATS = 6;
INTS = 7; INTS = 7;
...@@ -128,11 +135,12 @@ message AttributeProto { ...@@ -128,11 +135,12 @@ message AttributeProto {
TENSORS = 9; TENSORS = 9;
GRAPHS = 10; GRAPHS = 10;
SPARSE_TENSORS = 12; SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
} }
// The name field MUST be present for this version of the IR. // The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute optional string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute // In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope. // in parent scope.
...@@ -159,6 +167,7 @@ message AttributeProto { ...@@ -159,6 +167,7 @@ message AttributeProto {
optional SparseTensorProto sparse_tensor = 22; // sparse tensor value optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated. // Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph // optional ValueProto v = 12; // value - subsumes everything but graph
optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints repeated int64 ints = 8; // list of ints
...@@ -166,6 +175,7 @@ message AttributeProto { ...@@ -166,6 +175,7 @@ message AttributeProto {
repeated TensorProto tensors = 10; // list of tensors repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
} }
// Defines information on value, including the name, the type, and // Defines information on value, including the name, the type, and
...@@ -185,7 +195,7 @@ message ValueInfoProto { ...@@ -185,7 +195,7 @@ message ValueInfoProto {
// Computation graphs are made up of a DAG of nodes, which represent what is // Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks. // commonly called a "layer" or "pipeline stage" in machine learning frameworks.
// //
// For example, it can be a node of type "Conv" that takes in an image, a filter // For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output. // tensor and a bias tensor, and produces the convolved output.
message NodeProto { message NodeProto {
repeated string input = 1; // namespace Value repeated string input = 1; // namespace Value
...@@ -211,7 +221,7 @@ message NodeProto { ...@@ -211,7 +221,7 @@ message NodeProto {
// TrainingInfoProto stores information for training a model. // TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step // In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model // and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been consumed. // back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data. // Training algorithm improves the model based on input data.
// //
// The semantics of the initialization-step is that the initializers // The semantics of the initialization-step is that the initializers
...@@ -224,8 +234,8 @@ message NodeProto { ...@@ -224,8 +234,8 @@ message NodeProto {
// training algorithm's step. After the execution of a // training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding" // TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains // may be immediately updated. If the targeted training algorithm contains
// consecutive update stages (such as block coordinate descent methods), // consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each stage. // the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto { message TrainingInfoProto {
// This field describes a graph to compute the initial tensors // This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input // upon starting the training process. Initialization graph has no input
...@@ -239,24 +249,42 @@ message TrainingInfoProto { ...@@ -239,24 +249,42 @@ message TrainingInfoProto {
// iteration to zero. // iteration to zero.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Thus, no initializer would be changed by default.
optional GraphProto initialization = 1; optional GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs, // This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's // it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node, // initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference // optimizer node, increment of iteration count.
// graph.
// //
// The field algorithm.node is the only place the user can use GraphCall // An execution of the training algorithm step is performed by executing the
// operator. The only callable graph is the one stored in ModelProto.graph. // graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Evaluating the default training step never
// update any initializers.
optional GraphProto algorithm = 2; optional GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to // This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and // some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto. // the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details. // See "update_binding" below for details.
// //
...@@ -284,23 +312,16 @@ message TrainingInfoProto { ...@@ -284,23 +312,16 @@ message TrainingInfoProto {
// be multiple key-value pairs in "update_binding". // be multiple key-value pairs in "update_binding".
// //
// The initializers appears as keys in "update_binding" are considered // The initializers appears as keys in "update_binding" are considered
// mutable and globally-visible variables. This implies some behaviors // mutable variables. This implies some behaviors
// as described below. // as described below.
// //
// 1. We have only unique keys in all "update_binding"s so that two global // 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one // variables may not have the same name. This ensures that one
// global variable is assigned up to once. // variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or // 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer". // "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm". // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. If an optional input of a graph is omitted when using GraphCall, the // 4. Mutable variables are initialized to the value specified by the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by // corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
// //
...@@ -375,13 +396,31 @@ message ModelProto { ...@@ -375,13 +396,31 @@ message ModelProto {
// //
// If this field is empty, the training behavior of the model is undefined. // If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20; repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
}; };
// StringStringEntryProto follows the pattern for cross-proto-version maps. // StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps // See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto { message StringStringEntryProto {
optional string key = 1; optional string key = 1;
optional string value= 2; optional string value = 2;
}; };
message TensorAnnotation { message TensorAnnotation {
...@@ -397,7 +436,7 @@ message TensorAnnotation { ...@@ -397,7 +436,7 @@ message TensorAnnotation {
// Graphs // Graphs
// //
// A graph defines the computational logic of a model and is comprised of a parameterized // A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs. // list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning // This is the equivalent of the "network" or "graph" in many deep learning
// frameworks. // frameworks.
...@@ -409,8 +448,9 @@ message GraphProto { ...@@ -409,8 +448,9 @@ message GraphProto {
optional string name = 2; // namespace Graph optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph. // A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// MAY also appear in the input list. // The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5; repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format. // Initializers (see above) stored in sparse format.
...@@ -433,13 +473,8 @@ message GraphProto { ...@@ -433,13 +473,8 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14; repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions. reserved 3, 4, 6 to 9;
// repeated string input = 3; reserved "ir_version", "producer_version", "producer_tag", "domain";
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
} }
// Tensors // Tensors
...@@ -474,6 +509,17 @@ message TensorProto { ...@@ -474,6 +509,17 @@ message TensorProto {
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16; BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here. // Future extensions go here.
} }
...@@ -507,11 +553,11 @@ message TensorProto { ...@@ -507,11 +553,11 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true]; repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16 values // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer. // to writing to the buffer.
// When this field is present, the data_type field MUST be // When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true]; repeated int32 int32_data = 5 [packed = true];
// For strings. // For strings.
...@@ -589,6 +635,8 @@ message TensorProto { ...@@ -589,6 +635,8 @@ message TensorProto {
message SparseTensorProto { message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ]. // The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors. // The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
optional TensorProto values = 1; optional TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats. // The indices of the non-default values, which may be stored in one of two formats.
...@@ -619,7 +667,7 @@ message TensorShapeProto { ...@@ -619,7 +667,7 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor // Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure // dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor. // that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations. // for pre-defined dimension denotations.
optional string denotation = 3; optional string denotation = 3;
}; };
...@@ -656,6 +704,23 @@ message TypeProto { ...@@ -656,6 +704,23 @@ message TypeProto {
optional TypeProto value_type = 2; optional TypeProto value_type = 2;
}; };
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
optional TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
oneof value { oneof value {
// The type of a tensor. // The type of a tensor.
...@@ -672,11 +737,18 @@ message TypeProto { ...@@ -672,11 +737,18 @@ message TypeProto {
// The type of a map. // The type of a map.
Map map_type = 5; Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
} }
// An optional denotation can be used to denote the whole // An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is // type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations. // for pre-defined type denotations.
optional string denotation = 6; optional string denotation = 6;
} }
...@@ -696,7 +768,67 @@ message OperatorSetIdProto { ...@@ -696,7 +768,67 @@ message OperatorSetIdProto {
optional int64 version = 2; optional int64 version = 2;
} }
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
optional string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
optional string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
// For using protobuf-lite repeated OperatorSetIdProto opset_import = 9;
option optimize_for = LITE_RUNTIME;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
\ No newline at end of file
...@@ -34,7 +34,9 @@ ...@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case onnx::AttributeProto::TENSORS: case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR: case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS: case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::TYPE_PROTOS:
case onnx::AttributeProto::TYPE_PROTO:
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type())); MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data()); return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(),
data_int32.end(),
std::back_inserter(data_fp8),
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}
case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype) ...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type;
case 14:
case 15:
case 16:
case 17:
case 19:
case 20:
default: { default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"}, {"Neg", "neg"},
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "nearbyint"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_isinf : op_parser<parse_isinf>
{
std::vector<op_desc> operators() const { return {{"IsInf", "isinf"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
bool detect_negative = true;
bool detect_positive = true;
if(contains(info.attributes, "detect_negative"))
{
detect_negative = static_cast<bool>(
parser.parse_value(info.attributes.at("detect_negative")).at<int>());
}
if(contains(info.attributes, "detect_positive"))
{
detect_positive = static_cast<bool>(
parser.parse_value(info.attributes.at("detect_positive")).at<int>());
}
auto x_shape = args[0]->get_shape();
if(not detect_negative and not detect_positive)
{
return info.add_instruction(
make_op("multibroadcast", {{"out_lens", x_shape.lens()}}),
info.add_literal(migraphx::literal{migraphx::shape{shape::bool_type}, {false}}));
}
auto is_inf = info.add_instruction(make_op("isinf"), args[0]);
if(detect_negative and detect_positive)
{
return is_inf;
}
auto zero_l = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
auto mb_zero =
info.add_instruction(make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), zero_l);
auto cond = info.add_broadcastable_binary_op(
detect_negative ? "less" : "greater", args[0], mb_zero);
if(cond->get_shape().type() != shape::bool_type)
{
cond =
info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), cond);
}
return info.add_instruction(make_op("logical_and"), is_inf, cond);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop> ...@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
} }
} }
// cap max_iter because loop uses static shapes with max_iter size and huge numbers
// here can cause overflow
if(max_iterations > parser.limit_max_iterations)
{
std::cerr << "WARNING: PARSE_LOOP max_iterations exceeds the maximum loop "
"iterations limit, it will be changed from "
<< max_iterations << " to " << parser.limit_max_iterations << ".\n";
max_iterations = parser.limit_max_iterations;
}
// condition input is empty // condition input is empty
if(args.at(1)->name() == "undefined") if(args.at(1)->name() == "undefined")
{ {
......
...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv ...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
} }
} }
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
struct parse_lstm : op_parser<parse_lstm> struct parse_lstm : op_parser<parse_lstm>
{ {
std::vector<op_desc> operators() const { return {{"LSTM"}}; } std::vector<op_desc> operators() const { return {{"LSTM"}}; }
...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>(); input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
} }
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
if(args.size() < 8) if(args.size() < 8)
{ {
...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins); args.insert(args.end(), 8 - args.size(), ins);
} }
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm", auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output = auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -41,6 +41,9 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -41,6 +41,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
if(args.empty())
MIGRAPHX_THROW("PARSE_MULTINOMIAL: no arguments given");
int dtype = 6; int dtype = 6;
if(contains(info.attributes, "dtype")) if(contains(info.attributes, "dtype"))
dtype = info.attributes.at("dtype").i(); dtype = info.attributes.at("dtype").i();
...@@ -49,35 +52,90 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -49,35 +52,90 @@ struct parse_multinomial : op_parser<parse_multinomial>
size_t sample_size = 1; size_t sample_size = 1;
if(contains(info.attributes, "sample_size")) if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i(); sample_size = info.attributes.at("sample_size").i();
else
MIGRAPHX_THROW("PARSE_MULTINOMIAL: sample_size not given");
// Use logarithmic math to scale probabilities while avoiding division by very
// small numbers. Scaling by the maximum makes very tiny ranges more
// tractable; any constant factor gives equivalent distr. since the Multinomial op.
// normalizes at runtime.
// Subtract the per-batch maximum log-probability, making the per-batch max 0 // Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes = auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]); info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
auto mb_maxes = info.add_instruction( auto cdf = info.add_common_op("sub", args[0], maxes);
migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
maxes);
auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes);
// Take the element-wise exponent to get probabilities in the range (0, 1] // Take the element-wise exponent to get probabilities in the range (0, 1]
cdf = info.add_instruction(migraphx::make_op("exp"), cdf); cdf = info.add_instruction(migraphx::make_op("exp"), cdf);
// Compute the cumulative density function // Compute the cumulative distribution function
cdf = info.add_instruction( cdf = info.add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// Pre-compute random distribution instruction_ref seed_input;
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed")) if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f()); {
float seed = info.attributes.at("seed").f();
migraphx::shape s{migraphx::shape::float_type, {1}};
std::vector<float> data = {seed};
seed_input = info.add_literal(migraphx::literal(s, data));
}
else
{
seed_input = info.add_instruction(migraphx::make_op("random_seed"));
}
instruction_ref randoms;
shape s0 = args[0]->get_shape();
if(s0.dynamic())
{
// Dynamic batch_size will be taken from args[0]. The input argument to this should
// have a second dimension of sample_size.
std::vector<shape::dynamic_dimension> dyn_dim_set;
dyn_dim_set.emplace_back(s0.dyn_dims().front());
dyn_dim_set.emplace_back(shape::dynamic_dimension{sample_size, sample_size});
// read the input dimensions
auto dim_of =
info.add_instruction(migraphx::make_op("dimensions_of", {{"end", 2}}), args[0]);
// The next two operations insert the value sample_size into the second array position
// make an argument of (1, 0)
shape s(shape::int64_type, {2});
std::vector<int64_t> data1{1, 0};
auto l1 = info.add_literal(s, data1);
auto batch_arg = info.add_instruction(migraphx::make_op("mul"), dim_of, l1);
std::vector<int64_t> data2(2, 0);
// make an argument of (0, sample_size)
data2[1] = sample_size;
auto l2 = info.add_literal(s, data2);
auto alloc_shape = info.add_instruction(migraphx::make_op("add"), batch_arg, l2);
// alloc_shape should contain the input-based shape dimensions as its values at runtime,
// and its own shape is {2}
std::uniform_real_distribution<> dis(0.0, 1.0); // compile_shape is the shape used when compiling the Allocate op, and may be dynamic
size_t batch_size = args[0]->get_shape().lens().front(); migraphx::shape compile_shape =
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape(s0.type(), {s0.dyn_dims().front(), {sample_size, sample_size}});
std::vector<float> random_dist(batch_size * sample_size); // Allocate on-device storage for the random values
std::generate(random_dist.begin(), random_dist.end(), [&]() { return dis(gen); }); auto alloc = info.add_instruction(
auto dist_lit = info.add_literal(migraphx::literal{dist_shape, random_dist}); migraphx::make_op("allocate", {{"shape", to_value(compile_shape)}}), alloc_shape);
randoms = info.add_instruction(migraphx::make_op("random_uniform"), seed_input, alloc);
}
else
{
// use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size .
size_t batch_size = s0.lens().front();
auto rand_dummy = info.add_literal(migraphx::literal{
migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
randoms =
info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
}
return info.add_instruction( return info.add_instruction(
migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, dist_lit); migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, randoms);
} }
}; };
......
...@@ -22,14 +22,8 @@ ...@@ -22,14 +22,8 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,68 +33,14 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -39,68 +33,14 @@ struct parse_pooling : op_parser<parse_pooling>
{ {
std::vector<op_desc> operators() const std::vector<op_desc> operators() const
{ {
return {{"AveragePool", "average"}, return {
{"GlobalAveragePool", "average"}, {"AveragePool", "average"},
{"GlobalMaxPool", "max"}, {"GlobalAveragePool", "average"},
{"MaxPool", "max"}, {"GlobalMaxPool", "max"},
{"LpPool", "lpnorm"}, {"MaxPool", "max"},
{"GlobalLpPool", "lpnorm"}}; {"LpPool", "lpnorm"},
} {"GlobalLpPool", "lpnorm"},
};
value handle_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values) const
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
} }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
...@@ -108,144 +48,8 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -108,144 +48,8 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::string mode = opd.op_name; return add_pooling_op(opd, std::move(info), args[0]);
const std::unordered_map<std::string, op::pooling_mode> mode_map = { };
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
std::vector<size_t>(in_shape.ndim() - 2, 1),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
}
}; };
} // namespace onnx } // namespace onnx
......
...@@ -36,7 +36,7 @@ namespace onnx { ...@@ -36,7 +36,7 @@ namespace onnx {
/* /*
********************************************************************************* *********************************************************************************
* Reference: see QLinearAdd in * * Reference: see QLinearAdd, QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md * * https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
********************************************************************************* *********************************************************************************
...@@ -49,6 +49,17 @@ namespace onnx { ...@@ -49,6 +49,17 @@ namespace onnx {
This version of the operator has been available since version 1 of the 'com.microsoft' operator This version of the operator has been available since version 1 of the 'com.microsoft' operator
set. set.
com.microsoft.QLinearMul
Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
support).
C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
General definition of binary QLinear* ops:
Inputs (7 - 8) Inputs (7 - 8)
A : T A : T
First operand. First operand.
...@@ -88,15 +99,18 @@ namespace onnx { ...@@ -88,15 +99,18 @@ namespace onnx {
*/ */
struct parse_qlinearadd : op_parser<parse_qlinearadd> struct parse_qlinearbinary : op_parser<parse_qlinearbinary>
{ {
std::vector<op_desc> operators() const { return {{"QLinearAdd"}}; } std::vector<op_desc> operators() const
{
return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}};
}
// basic type checking for QLinearAdd Operator // basic type checking for binary QLinear Operator
void check_inputs(const std::vector<instruction_ref>& args) const void check_inputs(const std::vector<instruction_ref>& args, const std::string& op_name) const
{ {
if(args.size() < 7) if(args.size() < 7)
MIGRAPHX_THROW("QLINEARADD: missing inputs"); MIGRAPHX_THROW(op_name + ": missing inputs");
const auto& in_a = args[0]; const auto& in_a = args[0];
const auto& in_b = args[3]; const auto& in_b = args[3];
...@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd> ...@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
auto type_a = sh_a.type(); auto type_a = sh_a.type();
auto type_b = sh_b.type(); auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type) if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type"); MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type) if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type"); MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_a != type_b) if(type_a != type_b)
MIGRAPHX_THROW("QLINEARADD: mismatched input types"); MIGRAPHX_THROW(op_name + ": mismatched input types");
} }
instruction_ref parse(const op_desc& /* opd */, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
check_inputs(args); check_inputs(args, opd.op_name);
// A // A
const auto& in_a = args[0]; const auto& in_a = args[0];
...@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd> ...@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
const auto& in_zero_pt_b = args[5]; const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info); auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);
// C = A + B // C = op(A, B)
auto out_c = info.add_common_op("add", dquant_a, dquant_b); auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b);
const auto& in_scale_c = args[6]; const auto& in_scale_c = args[6];
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -36,90 +37,56 @@ namespace onnx { ...@@ -36,90 +37,56 @@ namespace onnx {
/* /*
********************************************************************************* *********************************************************************************
* Reference: see QLinearGlobalAveragePool in * * Reference: see QLinearAveragePool and QLinearGlobalAveragePool in *
* github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md * * github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
********************************************************************************* *********************************************************************************
*/
QLinearGlobalAveragePool consumes an input tensor X and applies struct parse_qlinearpooling : op_parser<parse_qlinearpooling>
Average pooling across the values in the same channel. This is
equivalent to AveragePool with kernel size equal to the spatial
dimension of input tensor. Input is of type uint8_t or int8_t.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
channels_last : int
Inputs
X : T
Input data tensor from the previous operator; According to channels_last, dimensions for image case
are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
H and W are the height and the width of the data. For non image case, the dimensions are in the form
of (N x C x D1 x D2 ... Dn), or (N x D1 X D2 ... Dn x C) where N is the batch size.
x_scale : tensor(float)
Scale of quantized input 'X'. It must be a scalar.
x_zero_point : T
Zero point tensor for input 'X'. It must be a scalar.
y_scale : tensor(float)
Scale of quantized output 'Y'. It must be a scalar.
y_zero_point : T
Zero point tensor for output 'Y'. It must be a scalar.
Outputs
Y : T
Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
input. with the N and C value keep it value, while the other dimensions are all 1. Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to signed/unsigned int8 tensors.
*/
struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool>
{ {
std::vector<op_desc> operators() const { return {{"QLinearGlobalAveragePool"}}; } std::vector<op_desc> operators() const
// basic type checking for QLinearGlobalAveragePool Operator
void check_inputs(const std::vector<instruction_ref>& args) const
{ {
if(args.size() < 5) return {{"QLinearGlobalAveragePool", "average"}, {"QLinearAveragePool", "average"}};
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: missing inputs"); }
const auto& in_x = args[0]; void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
const auto& zero_pt_x = args[2]; {
const auto& zero_pt_y = args[4]; const auto& in_x = args[0];
const auto onnx_name = opd.onnx_name;
if(in_x->get_shape().ndim() <= 2) if(in_x->get_shape().ndim() <= 2)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: input dimensions too small"); MIGRAPHX_THROW(onnx_name + ": input dimensions too small");
auto type_x = in_x->get_shape().type(); auto type_x = in_x->get_shape().type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type) if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: unsupported input type"); MIGRAPHX_THROW(onnx_name + ": unsupported input type");
const auto& zero_pt_x = args[2];
if(type_x != zero_pt_x->get_shape().type()) if(type_x != zero_pt_x->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point"); MIGRAPHX_THROW(onnx_name + ": mismatched type: input zero point");
if(type_x != zero_pt_y->get_shape().type()) if(args.size() == 5)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point"); {
const auto& zero_pt_y = args[4];
if(type_x != zero_pt_y->get_shape().type())
MIGRAPHX_THROW(onnx_name + ": mismatched type: output zero point");
}
} }
instruction_ref parse(const op_desc& /* opd */, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
int channels_last = if(contains(info.attributes, "channel_last"))
parser.parse_value(info.attributes.at("channels_last")).template at<int>(); {
if(channels_last != 0) int channels_last =
MIGRAPHX_THROW( parser.parse_value(info.attributes.at("channels_last")).template at<int>();
"QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported"); if(channels_last != 0)
MIGRAPHX_THROW(opd.onnx_name + ": channels_last (N x D1..Dn x C) is not supported");
}
check_inputs(args); check_inputs(opd, args);
// Input: X // Input: X
...@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool ...@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool
const auto& zero_pt_x = args[2]; const auto& zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info); auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info);
// Output Y = globalaveragepool(X) // Output Y = pooling_op(X)
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = in_x->get_shape().lens();
std::vector<size_t> lengths(lens.begin() + 2, lens.end());
op.lengths = lengths;
op.padding = std::vector<size_t>(lens.size());
auto out_y = info.add_instruction(op, dquant_x);
const auto& scale_y = args[3]; auto out_y = add_pooling_op(opd, info, dquant_x);
const auto& zero_pt_y = args[4];
auto out_quant_y = bcast_qdq_instr("quantizelinear", out_y, scale_y, zero_pt_y, info); const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", out_y, in_scale_y, args[4], info));
return out_quant_y; // if no zero_pt: just broadcast the scale..
auto bcast_scale_y = bcast_scalar_instr(out_y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), out_y, bcast_scale_y));
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearSigmoid, QLinearLeakyRelu in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
com.microsoft.QLinearSigmoid
QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces
one output data (Tensor) where the function f(x) = quantize(Sigmoid(dequantize(x))), is applied to
the data tensor elementwise. Where the function Sigmoid(x) = 1 / (1 + exp(-x))
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
*****************************************************************************************************
com.microsoft.QLinearLeakyRelu
QLinearLeakyRelu takes quantized input data (Tensor), an argument alpha, and quantize parameter for
output, and produces one output data (Tensor) where the function f(x) = quantize(alpha *
dequantize(x)) for dequantize(x) < 0, f(x) = quantize(dequantize(x)) for dequantize(x) >= 0, is
applied to the data tensor elementwise.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
alpha : float
Coefficient of leakage.
******************************************************************************************************
Generic input layout of QLinear unary operators:
Inputs (4 - 5)
X : T
Input tensor
X_scale : tensor(float)
Input X's scale. It's a scalar, which means a per-tensor/layer quantization.
X_zero_point (optional) : T
Input X's zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Y_scale : tensor(float) Output Y's scale. It's a scalar, which means
a per-tensor/layer quantization.
Y_zero_point (optional) : T Output Y's zero point. Default value is
0 if it's not specified. It's a scalar, which means a per-tensor/layer quantization.
Outputs
Y : T
Output tensor
Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit tensors.
*/
struct parse_qlinearunary : op_parser<parse_qlinearunary>
{
std::vector<op_desc> operators() const
{
return {{"QLinearSigmoid", "sigmoid"}, {"QLinearLeakyRelu", "leaky_relu"}};
}
void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
if(args.size() < 4)
MIGRAPHX_THROW(opd.op_name + ": missing inputs");
const auto& in_x = args[0];
auto sh_x = in_x->get_shape();
auto type_x = sh_x.type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW(opd.op_name + ": unsupported input type");
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(opd, args);
// X
const auto& in_x = args[0];
const auto& in_scale_x = args[1];
const auto& in_zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, in_scale_x, in_zero_pt_x, info);
// Y = (op(dequantizelinear(x))
auto op = parser.load(opd.op_name, info);
auto y = info.add_instruction(op, dquant_x);
const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", y, in_scale_y, args[4], info));
// if no zero_pt: just broadcast the scale..
auto bcast_scale_sigm = bcast_scalar_instr(y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), y, bcast_scale_sigm));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment