Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -3,24 +3,42 @@
//
// Copyright (c) Facebook Inc. and Microsoft Corporation.
// Copyright (c) ONNX Project Contributors.
// Licensed under the MIT license.
syntax = "proto2";
package onnx;
package onnx_for_migraphx;
// Note [Release]
// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
// Note [Protobuf compatibility]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Based on experience working with downstream vendors, we generally can't
// assume recent versions of protobufs. This means that we do not use any
// protobuf features that are only available in proto3.
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
......@@ -29,30 +47,11 @@ package onnx;
// of key-value pairs, where order does not matter and duplicates
// are not allowed.
// Note [Namespaces]
// ~~~~~~~~~~~~~~~~~
// ONNX gives explicit names to graphs, intermediate values and
// serialized tensors. To make it easier to generate names, we organize
// these into separate namespaces (so, e.g., a graph can have the same
// name as a serialized tensor.) The namespaces are as follows:
//
// - Node: These names identify specific nodes in the graph (but not, necessarily
// any particular input or output of the node.
// - Graph: These names identify graphs in the protobuf.
// - Attribute: These names identify attribute names for extra attributes that
// are passed to operators.
// - Operator: These names identify particular operators.
// - Value: These names identify intermediate values (typically tensors) flowing through
// the computation of a graph.
// - Shape: These names represent parameters for unknown shape dimensions.
// Versioning
//
// We specify the namespace of a name in ONNX as comments in the form
// of "namespace {Node,Graph,Operator,Attribute,Value,Shape}". Framework is responsible
// for supporting the namespaces.
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// Naming things is hard. Every element with a name has an optional doc_string associated
// with it, providing a human-readable description in text markdown.
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
......@@ -61,26 +60,53 @@ enum Version {
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control. We should use version as
// xx(major) - xx(minor) - xxxx(bugfix)
// and we are starting with 0x00000001 (0.0.1), which was the
// version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x00000001;
// control.
// For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;
// IR_VERSION 0.0.2 published on Oct 30, 2017
// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x00000002;
IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 0.0.3 published on Nov 3, 2017
// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION = 0x00000003;
IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on <TBD>
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator.
IR_VERSION = 0x0000000000000007;
}
// A named attribute containing either singular float, integer, string
// and tensor values, or repeated float, integer, string and tensor values.
// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
......@@ -94,26 +120,34 @@ message AttributeProto {
STRING = 3;
TENSOR = 4;
GRAPH = 5;
SPARSE_TENSOR = 11;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
SPARSE_TENSORS = 12;
}
// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
optional string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed.
optional string doc_string = 13;
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field hueristics to determine
// implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accomodate proto3 implementations.
// change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
......@@ -122,6 +156,7 @@ message AttributeProto {
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
......@@ -130,6 +165,7 @@ message AttributeProto {
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
}
// Defines information on value, including the name, the type, and
......@@ -137,16 +173,20 @@ message AttributeProto {
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
// This field MUST be present in this version of the IR.
// This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
}
// NodeProto stores a node that is similar to the notion of "layer"
// or "operator" in many deep learning frameworks. For example, it can be a
// node of type "Conv" that takes in an image, a filter tensor and a bias
// tensor, and produces the convolved output.
// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value
......@@ -161,18 +201,125 @@ message NodeProto {
optional string domain = 7; // namespace Domain
// Additional named attributes.
// NOTE: Simply using ValueProto.NameValuePairProto is the most general
// solution. I kept AttributeProto to minimize churn on CI results.
repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed.
optional string doc_string = 6;
}
// ModelProto is a top-level file/container format for bundling a ML model.
// The semantics of the model are described by the GraphProto that represents
// a parameterized computation graph against a set of named operators that are
// defined independently from the graph.
// Training information
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been consumed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update stages (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each stage.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iteration to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output.
optional GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference
// graph.
//
// The field algorithm.node is the only place the user can use GraphCall
// operator. The only callable graph is the one stored in ModelProto.graph.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output.
optional GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers appears as keys in "update_binding" are considered
// mutable and globally-visible variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two global
// variables may not have the same name. This ensures that one
// global variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm".
// 4. If an optional input of a graph is omitted when using GraphCall, the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
......@@ -217,6 +364,17 @@ message ModelProto {
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
......@@ -226,25 +384,38 @@ message StringStringEntryProto {
optional string value= 2;
};
// GraphProto defines a parameterized series of nodes to form a directed acyclic graph.
// This is the equivalent of the "network" and "graph" in many deep learning
message TensorAnnotation {
optional string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph.
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;
// The name of the graph.
optional string name = 2; // namespace Graph
// A list of named tensor values (constants), used to specify default
// values for some of the inputs of the graph.
// A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that
// also appears in the input list.
// In an evaluation, the default value specified here is used if and only if
// user specifies no value for the corresponding input parameter.
// May be used to pass serialized parameters for networks.
// MAY also appear in the input list.
repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;
......@@ -256,7 +427,13 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
// DO NOT USE the following fields, they were deprecated before
// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions.
// repeated string input = 3;
// repeated string output = 4;
// optional int64 ir_version = 6;
......@@ -265,7 +442,9 @@ message GraphProto {
// optional string domain = 9;
}
// A message defined to store a tensor in its serialized format.
// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
......@@ -280,13 +459,21 @@ message TensorProto {
STRING = 8; // string
BOOL = 9; // bool
// Advanced types
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here.
}
......@@ -294,7 +481,8 @@ message TensorProto {
repeated int64 dims = 1;
// The data type of the tensor.
optional DataType data_type = 2;
// This field MUST have a valid TensorProto.DataType value
optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
......@@ -305,7 +493,7 @@ message TensorProto {
}
optional Segment segment = 3;
// Tensor content must be in the row major order.
// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.
......@@ -313,7 +501,7 @@ message TensorProto {
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
......@@ -323,7 +511,7 @@ message TensorProto {
// float16 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
repeated int32 int32_data = 5 [packed = true];
// For strings.
......@@ -360,10 +548,32 @@ message TensorProto {
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
optional DataLocation data_location = 14;
// For double
// Complex64 tensors are encoded as a single array of doubles,
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the
// and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
......@@ -375,6 +585,28 @@ message TensorProto {
repeated uint64 uint64_data = 11 [packed = true];
}
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
optional TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
optional TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
......@@ -384,28 +616,73 @@ message TensorShapeProto {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}
// Define the types.
// Types
//
// The standard ONNX data types.
message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional TensorProto.DataType elem_type = 1;
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
optional TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
optional int32 key_type = 1;
// This field MUST be present for this version of the IR.
optional TypeProto value_type = 2;
};
oneof value {
// The type of a tensor.
Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
}
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
optional string denotation = 6;
}
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
......@@ -418,3 +695,8 @@ message OperatorSetIdProto {
// This field MUST be present in this version of the IR.
optional int64 version = 2;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/fallthrough.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/common.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
{
std::unordered_map<std::string, onnx::AttributeProto> result;
for(auto&& attr : node.attribute())
{
result[attr.name()] = attr;
}
return result;
}
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
{
// empty input
auto elem_num =
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0)
{
return {};
}
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
{
// empty input
auto elem_num =
std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
if(elem_num == 0)
{
return {};
}
// scalar input
if(dims.empty())
return literal{{shape_type}, data.begin(), data.end()};
return literal{{shape_type, dims}, data.begin(), data.end()};
}
template <class T>
static literal from_repeated(shape::type_t t, const T& r)
{
std::size_t size = r.size();
return literal{{t, {size}}, r.begin(), r.end()};
}
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{
auto attr = ins->get_operator().to_value();
std::string key = "require_std_shape";
if((attr.get(key, false)) or (not ins->get_shape().standard()))
{
return add_instruction(make_op("contiguous"), ins);
}
return ins;
}
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins,
uint64_t axis) const
{
if(args.size() == 3)
{
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
}
instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const
{
return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
}
instruction_ref
onnx_parser::node_info::add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const
{
return mod->add_instruction(op, args);
}
instruction_ref onnx_parser::node_info::add_instruction(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods) const
{
return mod->add_instruction(op, args, mods);
}
instruction_ref onnx_parser::node_info::add_literal(literal l) const
{
return mod->add_literal(std::move(l));
}
onnx_parser::onnx_parser()
{
// Add all registered op parsers
for(auto&& name : get_op_parsers())
ops.emplace(name, get_op_parser(name));
}
operation onnx_parser::load(const std::string& name, const node_info& info) const
{
auto op = make_op(name);
auto v = op.to_value();
for(auto&& x : v)
{
if(info.attributes.count(x.get_key()) == 0)
continue;
literal s = parse_value(info.attributes.at(x.get_key()));
if(x.is_array())
{
std::vector<value> values;
s.visit([&](auto y) {
std::transform(y.begin(), y.end(), std::back_inserter(values), [](auto z) {
return value(z);
});
});
x = values;
}
else
{
s.visit([&](auto y) { x = y.front(); });
}
}
op.from_value(v);
return op;
}
void onnx_parser::parse_undefined(module* mod, const std::string& name)
{
if(!contains(instructions, name))
{
auto ins = mod->add_instruction(make_op("undefined"));
instructions[name] = ins;
}
}
void onnx_parser::parse_from(std::istream& is, std::string name)
{
auto* mm = prog.get_main_module();
this->filename = std::move(name);
auto parent_path = fs::path(this->filename).parent_path();
if(not parent_path.empty())
this->path = parent_path;
onnx::ModelProto model;
if(model.ParseFromIstream(&is))
{
auto version = get_opset_version(model);
opset_version = (version == -1) ? opset_version : version;
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
}
}
else
{
MIGRAPHX_THROW("PARSE_FROM: Failed reading onnx file: " + this->filename);
}
}
void onnx_parser::parse_from(const void* data, std::size_t size)
{
auto* mm = prog.get_main_module();
onnx::ModelProto model;
if(model.ParseFromArray(data, size))
{
auto version = get_opset_version(model);
opset_version = (version == -1) ? opset_version : version;
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
}
}
else
{
MIGRAPHX_THROW("Failed reading onnx file.");
}
}
int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
{
const auto& opset_import = model.opset_import();
int64_t version = -1;
for(const auto& opset : opset_import)
{
if(opset.has_version())
{
version = std::max(version, opset.version());
}
}
return version;
}
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
// backup instructions in parent mod
mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
if(!contains(mod_insts, name))
{
// ONNX specification does not specify hwo to deal with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
if(contains(instructions, name))
{
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!");
}
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s);
}
}
std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node())
{
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
if(input.empty())
{
this->parse_undefined(mod, input);
}
if(instructions.count(input) == 0)
{
MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
"\" is unavailable due to unordered nodes!");
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
std::size_t output_num = static_cast<std::size_t>(node.output().size());
if(ops.count(node.op_type()) == 0)
{
if(skip_unknown_operators)
result.push_back(mod->add_instruction(op::unknown{node.op_type()}, args));
else
MIGRAPHX_THROW("Unknown operator: " + node.op_type());
}
else
{
std::string node_name = node.op_type() + "_" + std::to_string(mod->size());
result = ops[node.op_type()](
*this, {get_attributes(node), output_num, node_name, mod}, args);
}
output_num = std::min<std::size_t>(output_num, result.size());
std::transform(node.output().begin(),
node.output().begin() + output_num,
result.begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(x, y); });
}
// Find instructions corresponding to the output
auto prog_output = graph.output();
std::vector<std::string> all_output_names;
std::vector<std::string> prog_output_names;
std::transform(prog_output.begin(),
prog_output.end(),
std::back_inserter(all_output_names),
[](auto& node) { return node.name(); });
std::copy_if(
all_output_names.begin(),
all_output_names.end(),
std::back_inserter(prog_output_names),
[&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });
std::vector<instruction_ref> output_ins;
std::transform(prog_output_names.begin(),
prog_output_names.end(),
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
// add the return instuction
mod->add_return(output_ins);
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
{
switch(attr.type())
{
case onnx::AttributeProto::FLOAT: return literal{attr.f()};
case onnx::AttributeProto::INT: return literal{attr.i()};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
case onnx::AttributeProto::UNDEFINED:
case onnx::AttributeProto::GRAPH:
case onnx::AttributeProto::STRING:
case onnx::AttributeProto::STRINGS:
case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
}
literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(not t.external_data().empty())
{
const std::string& data_file = t.external_data().at(0).value();
auto raw_buffer = read_buffer(path + "/" + data_file);
std::string s(raw_buffer.begin(), raw_buffer.end());
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data());
}
if(t.has_raw_data())
{
const std::string& s = t.raw_data();
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data());
}
switch(t.data_type())
{
case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data());
case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data());
case onnx::TensorProto::UINT8: return create_literal(shape::uint8_type, dims, t.int32_data());
case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, t.int32_data());
case onnx::TensorProto::UINT16: return create_literal(shape::uint16_type, dims, t.int32_data());
case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT32:
return create_literal(shape::uint32_type, dims, t.uint64_data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data());
case onnx::TensorProto::FLOAT16: {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
}
shape onnx_parser::parse_type(const onnx::TypeProto& t,
const std::vector<std::size_t>& input_dims) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(!input_dims.empty())
{
return {shape_type, input_dims};
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[&](auto&& d) -> std::size_t {
if(d.has_dim_value())
{
if(static_cast<int>(d.dim_value()) <= 0)
{
return default_dim_value;
}
return d.dim_value();
}
else
{
return default_dim_value;
}
});
if(dims.empty())
return {shape_type};
return {shape_type, dims};
}
shape::type_t get_type(int dtype)
{
switch(dtype)
{
case 1: return shape::float_type;
case 2: return shape::uint8_type;
case 3: return shape::int8_type;
case 4: return shape::uint16_type;
case 5: return shape::int16_type;
case 6: return shape::int32_type;
case 7: return shape::int64_type;
case 9: return shape::bool_type;
case 10: return shape::half_type;
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
std::unordered_map<std::string, onnx_parser::op_func>& op_parser_map()
{
static std::unordered_map<std::string, onnx_parser::op_func> m; // NOLINT
return m;
}
void register_op_parser(const std::string& name, onnx_parser::op_func f)
{
op_parser_map()[name] = std::move(f);
}
onnx_parser::op_func get_op_parser(const std::string& name) { return op_parser_map().at(name); }
std::vector<std::string> get_op_parsers()
{
std::vector<std::string> result;
std::transform(op_parser_map().begin(),
op_parser_map().end(),
std::back_inserter(result),
[&](auto&& p) { return p.first; });
return result;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/padding.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void cal_auto_padding_size(onnx_parser::node_info info,
value& v,
const std::vector<std::size_t>& k_lens,
const std::vector<std::size_t>& dilation,
const std::vector<std::size_t>& in_lens,
std::vector<int64_t>& paddings)
{
size_t kdims = in_lens.size() - 2;
assert(k_lens.size() == kdims and dilation.size() == kdims);
if(!contains(info.attributes, "auto_pad"))
{
return;
}
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
paddings.resize(2 * kdims);
for(size_t i = 0; i < paddings.size() / 2; i++)
{
calculate_padding(i,
paddings,
in_lens[i + 2],
v["stride"][i].to<int64_t>(),
dilation[i],
k_lens[i],
is_same_upper);
}
}
}
bool is_asym_padding(const std::vector<int64_t>& padding)
{
assert(padding.size() % 2 == 0);
size_t pad_ndims = padding.size() / 2;
for(size_t i = 0; i < pad_ndims; i++)
{
if(padding[i] != padding[i + pad_ndims])
{
return true;
}
}
return false;
}
void check_padding_mode(const onnx_parser::node_info& info, const std::string& op_name)
{
// ensure pads availabe only when auto_pad is "NOT_SET"
if(contains(info.attributes, "pads") and contains(info.attributes, "auto_pad"))
{
auto s = info.attributes.at("auto_pad").s();
if(to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("PARSE_" + op_name +
": auto_pad and padding cannot be specified simultaneously");
}
}
}
static void
tune_padding_to_symmetric(int64_t& left, int64_t& right, const int stride, int64_t& s_start)
{
s_start = 0;
if(left > right)
{
right = left;
}
else if(left < right)
{
auto diff = right - left;
s_start = (diff + stride - 1) / stride;
left = left + s_start * stride;
right = left;
}
}
void tune_padding_size(const value& v,
std::vector<int64_t>& padding,
int count_include_pad,
std::vector<int64_t>& s_start)
{
// maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{
return;
}
// if padding is symmetric, return directly
if(!is_asym_padding(padding))
{
return;
}
// asymmetric padding, make it symmetric
std::size_t n_dims = padding.size() / 2;
s_start.resize(n_dims);
for(std::size_t i = 0; i < n_dims; ++i)
{
tune_padding_to_symmetric(
padding[i], padding[i + n_dims], v.at("stride")[i].to<int64_t>(), s_start[i]);
}
}
void check_asym_padding(const onnx_parser::node_info& info,
instruction_ref& ins,
const std::vector<int64_t>& padding,
value& v,
int count_include_pad,
float pad_val)
{
size_t pad_ndims = padding.size() / 2;
auto left_pad_it = padding.begin();
auto right_pad_it = left_pad_it + pad_ndims;
if(count_include_pad == 1)
{
std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
// add left pads
asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
// add right pads
asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
ins = info.add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
std::vector<size_t> new_padding(padding.size());
// subtract asym padding originally found from parsing the operator
std::transform(padding.begin(),
left_pad_it,
asym_pads.begin() + 2,
new_padding.begin(),
std::minus<size_t>());
std::transform(right_pad_it,
padding.end(),
asym_pads.begin() + pad_ndims + 4,
new_padding.begin() + pad_ndims,
std::minus<size_t>());
v["padding"] = new_padding;
}
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_arg_op : op_parser<parse_arg_op>
{
std::vector<op_desc> operators() const { return {{"ArgMax", "argmax"}, {"ArgMin", "argmin"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
{
axis = static_cast<int64_t>(parser.parse_value(info.attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(info.attributes, "keepdims"))
{
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
else
{
return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
enum class reduce_mode_t
{
sum = 0,
mean = 1,
max = 2
};
struct parse_aten : op_parser<parse_aten>
{
std::vector<op_desc> operators() const { return {{"ATen"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(contains(info.attributes, "operator"))
{
auto op_name = info.attributes.at("operator").s();
if(op_name.find("embedding_bag") != std::string::npos)
{
return parse_embedding_bag(info, std::move(args));
}
}
MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator");
}
instruction_ref parse_embedding_bag(onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(args[2]->get_shape().elements() != 1)
MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1");
reduce_mode_t reduce_mode = reduce_mode_t::sum;
if(contains(info.attributes, "mode"))
{
reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
}
auto l0 = info.add_instruction(make_op("gather"), args[0], args[1]);
switch(reduce_mode)
{
case reduce_mode_t::sum:
l0 = info.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::mean:
l0 = info.add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
break;
case reduce_mode_t::max:
l0 = info.add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
break;
}
return l0;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_batchnorm : op_parser<parse_batchnorm>
{
std::vector<op_desc> operators() const { return {{"BatchNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "momentum"))
{
momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
}
if(contains(info.attributes, "spatial"))
{
bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return info.add_instruction(op, args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_binary_op : op_parser<parse_binary_op>
{
std::vector<op_desc> operators() const
{
return {{"Add", "add"},
{"Div", "div"},
{"And", "logical_and"},
{"Or", "logical_or"},
{"Xor", "logical_xor"},
{"Mul", "mul"},
{"PRelu", "prelu"},
{"Sub", "sub"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
{
uint64_t broadcasted =
parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
return info.add_instruction(make_op(opd.op_name), args);
}
else
{
return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_cast : op_parser<parse_cast>
{
std::vector<op_desc> operators() const { return {{"Cast"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if(!contains(info.attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parser.parse_value(info.attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return info.add_instruction(make_op("convert", {{"target_type", type}}), args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
std::vector<op_desc> operators() const { return {{"Celu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
if(float_equal(alpha, 0.0f))
{
MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
if(input_type != migraphx::shape::float_type)
{
MIGRAPHX_THROW("CELU: input tensor not float type");
}
auto zero_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]);
auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit);
auto expo = info.add_instruction(migraphx::make_op("exp"), divi);
auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit);
auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub);
auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul);
return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_clip : op_parser<parse_clip>
{
std::vector<op_desc> operators() const { return {{"Clip"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
instruction_ref min_arg;
instruction_ref max_arg;
bool min_used = false;
bool max_used = false;
if(args.size() == 3 and args[2]->name() != "undefined")
{
max_arg = args[2];
max_used = true;
}
if(args.size() >= 2 and args[1]->name() != "undefined")
{
min_arg = args[1];
min_used = true;
}
// if using previous opset for attributes
else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
{
float min_val = parser.parse_value(info.attributes.at("min")).at<float>();
float max_val = parser.parse_value(info.attributes.at("max")).at<float>();
min_arg = info.add_literal(min_val);
max_arg = info.add_literal(max_val);
min_used = true;
max_used = true;
}
if(min_used and max_used)
{
return info.add_common_op("clip", args[0], min_arg, max_arg);
}
else if(max_used)
{
return info.add_broadcastable_binary_op("min", args[0], max_arg);
}
else if(min_used)
{
return info.add_broadcastable_binary_op("max", args[0], min_arg);
}
else
{
return info.add_instruction(make_op("identity"), args[0]);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_compare_op : op_parser<parse_compare_op>
{
std::vector<op_desc> operators() const
{
return {{"Equal", "equal"}, {"Greater", "greater"}, {"Less", "less"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto l = info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]);
if(l->get_shape().type() != shape::bool_type)
{
l = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
}
return l;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_constant : op_parser<parse_constant>
{
std::vector<op_desc> operators() const { return {{"Constant"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const
{
literal v = parser.parse_value(info.attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return info.add_literal(literal{});
}
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return info.add_literal(v);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
struct parse_constant_fill : op_parser<parse_constant_fill>
{
std::vector<op_desc> operators() const { return {{"ConstantFill"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(info.attributes, "dtype"))
{
dtype = parser.parse_value(info.attributes.at("dtype")).at<int>();
}
shape::type_t type = get_type(dtype);
if(contains(info.attributes, "input_as_shape"))
{
input_as_shape = parser.parse_value(info.attributes.at("input_as_shape")).at<int>();
}
if(contains(info.attributes, "value"))
{
value = parser.parse_value(info.attributes.at("value")).at<float>();
}
if(contains(info.attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return info.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
if(!contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parser.parse_value(info.attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return info.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
{
std::vector<op_desc> operators() const { return {{"ConstantOfShape"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
literal l_val{};
if(contains(info.attributes, "value"))
{
l_val = parser.parse_value(info.attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), val.front());
l_out = literal(s, out_vec);
});
return info.add_literal(l_out);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_convolution : op_parser<parse_convolution>
{
std::vector<op_desc> operators() const
{
return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}};
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto op = make_op(opd.op_name);
auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "CONV");
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
values["dilation"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations");
}
std::vector<int64_t> padding;
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
}
if(contains(info.attributes, "auto_pad"))
{
auto weight_lens = weights->get_shape().lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
cal_auto_padding_size(info,
values,
k_lens,
values["dilation"].to_vector<std::size_t>(),
in_lens,
padding);
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
if(contains(info.attributes, "group"))
{
values["group"] = parser.parse_value(info.attributes.at("group")).at<int>();
}
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = info.add_instruction(op, l0, args[1]);
return info.add_bias(args, l1, 1);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
template <class T>
std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
{
std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
return output_vector;
}
struct parse_deconvolution : op_parser<parse_deconvolution>
{
std::vector<op_desc> operators() const { return {{"ConvTranspose"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
operation op = make_op("deconvolution");
value values = op.to_value();
// op::deconvolution op;
auto l0 = args[0];
std::vector<std::int64_t> padding;
bool asym_padding = false;
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "CONV_TRANSPOSE");
if(contains(info.attributes, "pads"))
{
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
asym_padding = is_asym_padding(padding);
if(not asym_padding)
{
size_t pad_ndims = padding.size() / 2;
check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
values["padding"].clear();
std::transform(padding.begin(),
padding.begin() + pad_ndims,
std::back_inserter(values["padding"]),
[](auto pad_val) { return pad_val; });
}
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(
kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
}
if(contains(info.attributes, "dilations"))
{
values["dilation"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
check_attr_sizes(
kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
}
if(contains(info.attributes, "auto_pad"))
{
auto s = info.attributes["auto_pad"].s();
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
"simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
if(contains(info.attributes, "group"))
{
values["group"] = parser.parse_value(info.attributes.at("group")).at<int>();
}
recalc_conv_attributes(values, kdims);
op.from_value(values);
auto l1 = info.add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
if(asym_padding)
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims
auto pad_kdim_start = padding.begin() + kdims;
std::vector<int64_t> starts(padding.begin(), pad_kdim_start);
std::vector<int64_t> ends{};
std::transform(curr_shape.begin(),
curr_shape.end(),
pad_kdim_start,
std::back_inserter(ends),
[](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
}
if(contains(info.attributes, "output_padding"))
{
size_t non_kdims = dims.size() * 2 - kdims;
std::vector<int64_t> output_padding(non_kdims, 0);
copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
check_attr_sizes(kdims,
output_padding.size() - non_kdims,
"PARSE_CONV_TRANSPOSE: inconsistent output padding");
l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
}
if(contains(info.attributes, "output_shape"))
{
std::vector<int64_t> output_shape;
copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
check_attr_sizes(
kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
dims = to_int64_vector(l1->get_shape().lens());
copy(dims.begin() + 2, dims.end(), curr_shape.begin());
if(curr_shape != output_shape)
{
std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
std::transform(output_shape.begin(),
output_shape.end(),
curr_shape.begin(),
std::back_inserter(target_padding),
[](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
l1 = info.add_instruction(make_op("pad", {{"pads", target_padding}}), l1);
}
}
return info.add_bias(args, l1, 1);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_depthtospace : op_parser<parse_depthtospace>
{
std::vector<op_desc> operators() const { return {{"DepthToSpace"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto s = args[0]->get_shape();
// mode attribute of DepthToSpace
auto mode = std::string("DCR");
if(contains(info.attributes, "mode"))
{
mode = info.attributes.at("mode").s(); // DCR or CRD?
}
// blocksize attribute of DepthToSpace
int blocksize = 0;
if(contains(info.attributes, "blocksize"))
{
blocksize = info.attributes.at("blocksize").i();
}
if(blocksize < 1)
{
MIGRAPHX_THROW("DepthToSpace: blocksize is less than 1");
}
// calculate dimensions
auto lens1 = s.lens();
auto lens2 = s.lens();
unsigned long divisor = std::pow(blocksize, 2);
if((lens2[1] % divisor) == 0)
lens2[1] = lens2[1] / divisor;
else
MIGRAPHX_THROW("DepthToSpace: div by blocksize quotient not int ");
lens1.push_back(lens1[2]);
lens1.push_back(lens1[3]);
lens2[2] = lens2[2] * blocksize;
lens2[3] = lens2[3] * blocksize;
lens1[2] = blocksize;
std::vector<int64_t> perm;
if(mode == "DCR")
{
lens1[3] = lens1[1] / divisor;
lens1[1] = blocksize;
perm = {0, 3, 4, 1, 5, 2};
}
else if(mode == "CRD")
{
lens1[1] = lens1[1] / divisor;
lens1[3] = blocksize;
perm = {0, 1, 4, 2, 5, 3};
}
else
MIGRAPHX_THROW("DepthToSpace: mode attribute cannot be read.");
auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]);
auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
return info.add_instruction(make_op("reshape", {{"dims", lens2}}),
info.make_contiguous(temp2));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
std::vector<op_desc> operators() const { return {{"DequantizeLinear"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
int axis = 1;
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens();
auto n_dim = input_lens.size();
instruction_ref x_scale;
if(args[1]->get_shape().elements() != 1)
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
if(args.size() == 3)
{
auto x_zero_point = args[2];
if(x_zero_point->get_shape().elements() != 1)
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point);
}
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
}
return info.add_instruction(
make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
}
return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_dropout : op_parser<parse_dropout>
{
std::vector<op_desc> operators() const { return {{"Dropout"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto out = info.add_instruction(make_op("identity"), args[0]);
auto s = args[0]->get_shape();
std::vector<int8_t> vec(s.elements(), 1);
shape mask_s{shape::bool_type, s.lens()};
auto mask = info.add_literal(literal(mask_s, vec));
return {out, mask};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment