Commit 6f768035 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'rocblas_mlir_fp8' into miopen_fp8

parents da7717ce b2542239
...@@ -21,31 +21,26 @@ ...@@ -21,31 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/shape.hpp> #include <migraphx/onnx/onnx_parser.hpp>
#include <set> #include <migraphx/onnx/op_parser.hpp>
#include <string> #include <migraphx/instruction.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct module; value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
/** instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
This will insert convert operators for the operators that are not implemented for FP8 dtypes
*/
struct MIGRAPHX_EXPORT eliminate_fp8
{
// TODO: Add all device ops as a later PR and add tests for those.
std::set<std::string> op_names;
shape::type_t target_type = migraphx::shape::float_type;
std::string name() const { return "eliminate_fp8"; }
void apply(module& m) const;
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
// //
// Copyright (c) ONNX Project Contributors. // SPDX-License-Identifier: Apache-2.0
// Licensed under the MIT license.
syntax = "proto2"; syntax = "proto2";
...@@ -27,13 +27,6 @@ package onnx_for_migraphx; ...@@ -27,13 +27,6 @@ package onnx_for_migraphx;
// Notes // Notes
// //
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility // Protobuf compatibility
// //
// To simplify framework compatibility, ONNX is defined using the subset of protobuf // To simplify framework compatibility, ONNX is defined using the subset of protobuf
...@@ -92,15 +85,28 @@ enum Version { ...@@ -92,15 +85,28 @@ enum Version {
// - Add sparse initializers // - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006; IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on <TBD> // IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add a list to promote inference graph's initializers to global and // - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the // mutable variables. Global variables are visible in all graphs of the
// stored models. // stored models.
// - Add message TrainingInfoProto to store initialization // - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto // method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables. // can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator. // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION = 0x0000000000000007; IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on TBD
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
} }
// Attributes // Attributes
...@@ -121,6 +127,7 @@ message AttributeProto { ...@@ -121,6 +127,7 @@ message AttributeProto {
TENSOR = 4; TENSOR = 4;
GRAPH = 5; GRAPH = 5;
SPARSE_TENSOR = 11; SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6; FLOATS = 6;
INTS = 7; INTS = 7;
...@@ -128,6 +135,7 @@ message AttributeProto { ...@@ -128,6 +135,7 @@ message AttributeProto {
TENSORS = 9; TENSORS = 9;
GRAPHS = 10; GRAPHS = 10;
SPARSE_TENSORS = 12; SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
} }
// The name field MUST be present for this version of the IR. // The name field MUST be present for this version of the IR.
...@@ -159,6 +167,7 @@ message AttributeProto { ...@@ -159,6 +167,7 @@ message AttributeProto {
optional SparseTensorProto sparse_tensor = 22; // sparse tensor value optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated. // Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph // optional ValueProto v = 12; // value - subsumes everything but graph
optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints repeated int64 ints = 8; // list of ints
...@@ -166,6 +175,7 @@ message AttributeProto { ...@@ -166,6 +175,7 @@ message AttributeProto {
repeated TensorProto tensors = 10; // list of tensors repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
} }
// Defines information on value, including the name, the type, and // Defines information on value, including the name, the type, and
...@@ -211,7 +221,7 @@ message NodeProto { ...@@ -211,7 +221,7 @@ message NodeProto {
// TrainingInfoProto stores information for training a model. // TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step // In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model // and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been consumed. // back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data. // Training algorithm improves the model based on input data.
// //
// The semantics of the initialization-step is that the initializers // The semantics of the initialization-step is that the initializers
...@@ -224,8 +234,8 @@ message NodeProto { ...@@ -224,8 +234,8 @@ message NodeProto {
// training algorithm's step. After the execution of a // training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding" // TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains // may be immediately updated. If the targeted training algorithm contains
// consecutive update stages (such as block coordinate descent methods), // consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each stage. // the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto { message TrainingInfoProto {
// This field describes a graph to compute the initial tensors // This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input // upon starting the training process. Initialization graph has no input
...@@ -239,20 +249,38 @@ message TrainingInfoProto { ...@@ -239,20 +249,38 @@ message TrainingInfoProto {
// iteration to zero. // iteration to zero.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Thus, no initializer would be changed by default.
optional GraphProto initialization = 1; optional GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs, // This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's // it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node, // initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference // optimizer node, increment of iteration count.
// graph.
// //
// The field algorithm.node is the only place the user can use GraphCall // An execution of the training algorithm step is performed by executing the
// operator. The only callable graph is the one stored in ModelProto.graph. // graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Evaluating the default training step never
// update any initializers.
optional GraphProto algorithm = 2; optional GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to // This field specifies the bindings from the outputs of "initialization" to
...@@ -284,23 +312,16 @@ message TrainingInfoProto { ...@@ -284,23 +312,16 @@ message TrainingInfoProto {
// be multiple key-value pairs in "update_binding". // be multiple key-value pairs in "update_binding".
// //
// The initializers appears as keys in "update_binding" are considered // The initializers appears as keys in "update_binding" are considered
// mutable and globally-visible variables. This implies some behaviors // mutable variables. This implies some behaviors
// as described below. // as described below.
// //
// 1. We have only unique keys in all "update_binding"s so that two global // 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one // variables may not have the same name. This ensures that one
// global variable is assigned up to once. // variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or // 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer". // "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm". // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. If an optional input of a graph is omitted when using GraphCall, the // 4. Mutable variables are initialized to the value specified by the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by // corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
// //
...@@ -375,13 +396,31 @@ message ModelProto { ...@@ -375,13 +396,31 @@ message ModelProto {
// //
// If this field is empty, the training behavior of the model is undefined. // If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20; repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
}; };
// StringStringEntryProto follows the pattern for cross-proto-version maps. // StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps // See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto { message StringStringEntryProto {
optional string key = 1; optional string key = 1;
optional string value= 2; optional string value = 2;
}; };
message TensorAnnotation { message TensorAnnotation {
...@@ -409,8 +448,9 @@ message GraphProto { ...@@ -409,8 +448,9 @@ message GraphProto {
optional string name = 2; // namespace Graph optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph. // A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// MAY also appear in the input list. // The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5; repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format. // Initializers (see above) stored in sparse format.
...@@ -433,13 +473,8 @@ message GraphProto { ...@@ -433,13 +473,8 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14; repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions. reserved 3, 4, 6 to 9;
// repeated string input = 3; reserved "ir_version", "producer_version", "producer_tag", "domain";
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
} }
// Tensors // Tensors
...@@ -474,6 +509,17 @@ message TensorProto { ...@@ -474,6 +509,17 @@ message TensorProto {
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16; BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here. // Future extensions go here.
} }
...@@ -507,11 +553,11 @@ message TensorProto { ...@@ -507,11 +553,11 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true]; repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16 values // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer. // to writing to the buffer.
// When this field is present, the data_type field MUST be // When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true]; repeated int32 int32_data = 5 [packed = true];
// For strings. // For strings.
...@@ -589,6 +635,8 @@ message TensorProto { ...@@ -589,6 +635,8 @@ message TensorProto {
message SparseTensorProto { message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ]. // The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors. // The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
optional TensorProto values = 1; optional TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats. // The indices of the non-default values, which may be stored in one of two formats.
...@@ -619,7 +667,7 @@ message TensorShapeProto { ...@@ -619,7 +667,7 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor // Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure // dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor. // that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations. // for pre-defined dimension denotations.
optional string denotation = 3; optional string denotation = 3;
}; };
...@@ -656,6 +704,23 @@ message TypeProto { ...@@ -656,6 +704,23 @@ message TypeProto {
optional TypeProto value_type = 2; optional TypeProto value_type = 2;
}; };
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
optional TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
oneof value { oneof value {
// The type of a tensor. // The type of a tensor.
...@@ -672,11 +737,18 @@ message TypeProto { ...@@ -672,11 +737,18 @@ message TypeProto {
// The type of a map. // The type of a map.
Map map_type = 5; Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
} }
// An optional denotation can be used to denote the whole // An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is // type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations. // for pre-defined type denotations.
optional string denotation = 6; optional string denotation = 6;
} }
...@@ -696,7 +768,67 @@ message OperatorSetIdProto { ...@@ -696,7 +768,67 @@ message OperatorSetIdProto {
optional int64 version = 2; optional int64 version = 2;
} }
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
optional string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
optional string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
repeated OperatorSetIdProto opset_import = 9;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;
}
// For using protobuf-lite // For using protobuf-lite
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
\ No newline at end of file
...@@ -34,7 +34,9 @@ ...@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case onnx::AttributeProto::TENSORS: case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR: case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS: case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::TYPE_PROTOS:
case onnx::AttributeProto::TYPE_PROTO:
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type())); MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data()); return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(),
data_int32.end(),
std::back_inserter(data_fp8),
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}
case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype) ...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type;
case 14:
case 15:
case 16:
case 17:
case 19:
case 20:
default: { default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
} }
......
...@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long its // use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size . // number of elements is batch_size * sample_size .
size_t batch_size = s0.lens().front(); size_t batch_size = s0.lens().front();
auto rand_dummy = info.add_literal( auto rand_dummy = info.add_literal(migraphx::literal{
migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}}); migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
randoms = randoms =
info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy); info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
} }
......
...@@ -22,14 +22,8 @@ ...@@ -22,14 +22,8 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,68 +33,14 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -39,68 +33,14 @@ struct parse_pooling : op_parser<parse_pooling>
{ {
std::vector<op_desc> operators() const std::vector<op_desc> operators() const
{ {
return {{"AveragePool", "average"}, return {
{"AveragePool", "average"},
{"GlobalAveragePool", "average"}, {"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"}, {"GlobalMaxPool", "max"},
{"MaxPool", "max"}, {"MaxPool", "max"},
{"LpPool", "lpnorm"}, {"LpPool", "lpnorm"},
{"GlobalLpPool", "lpnorm"}}; {"GlobalLpPool", "lpnorm"},
} };
value handle_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values) const
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
} }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
...@@ -108,144 +48,8 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -108,144 +48,8 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::string mode = opd.op_name; return add_pooling_op(opd, std::move(info), args[0]);
const std::unordered_map<std::string, op::pooling_mode> mode_map = { };
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
std::vector<size_t>(in_shape.ndim() - 2, 1),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
}
}; };
} // namespace onnx } // namespace onnx
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -36,90 +37,56 @@ namespace onnx { ...@@ -36,90 +37,56 @@ namespace onnx {
/* /*
********************************************************************************* *********************************************************************************
* Reference: see QLinearGlobalAveragePool in * * Reference: see QLinearAveragePool and QLinearGlobalAveragePool in *
* github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md * * github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
********************************************************************************* *********************************************************************************
*/
QLinearGlobalAveragePool consumes an input tensor X and applies struct parse_qlinearpooling : op_parser<parse_qlinearpooling>
Average pooling across the values in the same channel. This is
equivalent to AveragePool with kernel size equal to the spatial
dimension of input tensor. Input is of type uint8_t or int8_t.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
channels_last : int
Inputs
X : T
Input data tensor from the previous operator; According to channels_last, dimensions for image case
are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
H and W are the height and the width of the data. For non image case, the dimensions are in the form
of (N x C x D1 x D2 ... Dn), or (N x D1 X D2 ... Dn x C) where N is the batch size.
x_scale : tensor(float)
Scale of quantized input 'X'. It must be a scalar.
x_zero_point : T
Zero point tensor for input 'X'. It must be a scalar.
y_scale : tensor(float)
Scale of quantized output 'Y'. It must be a scalar.
y_zero_point : T
Zero point tensor for output 'Y'. It must be a scalar.
Outputs
Y : T
Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
input. with the N and C value keep it value, while the other dimensions are all 1. Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to signed/unsigned int8 tensors.
*/
struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool>
{ {
std::vector<op_desc> operators() const { return {{"QLinearGlobalAveragePool"}}; } std::vector<op_desc> operators() const
// basic type checking for QLinearGlobalAveragePool Operator
void check_inputs(const std::vector<instruction_ref>& args) const
{ {
if(args.size() < 5) return {{"QLinearGlobalAveragePool", "average"}, {"QLinearAveragePool", "average"}};
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: missing inputs"); }
void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
const auto& in_x = args[0]; const auto& in_x = args[0];
const auto& zero_pt_x = args[2]; const auto onnx_name = opd.onnx_name;
const auto& zero_pt_y = args[4];
if(in_x->get_shape().ndim() <= 2) if(in_x->get_shape().ndim() <= 2)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: input dimensions too small"); MIGRAPHX_THROW(onnx_name + ": input dimensions too small");
auto type_x = in_x->get_shape().type(); auto type_x = in_x->get_shape().type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type) if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: unsupported input type"); MIGRAPHX_THROW(onnx_name + ": unsupported input type");
const auto& zero_pt_x = args[2];
if(type_x != zero_pt_x->get_shape().type()) if(type_x != zero_pt_x->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point"); MIGRAPHX_THROW(onnx_name + ": mismatched type: input zero point");
if(args.size() == 5)
{
const auto& zero_pt_y = args[4];
if(type_x != zero_pt_y->get_shape().type()) if(type_x != zero_pt_y->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point"); MIGRAPHX_THROW(onnx_name + ": mismatched type: output zero point");
}
} }
instruction_ref parse(const op_desc& /* opd */, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{
if(contains(info.attributes, "channel_last"))
{ {
int channels_last = int channels_last =
parser.parse_value(info.attributes.at("channels_last")).template at<int>(); parser.parse_value(info.attributes.at("channels_last")).template at<int>();
if(channels_last != 0) if(channels_last != 0)
MIGRAPHX_THROW( MIGRAPHX_THROW(opd.onnx_name + ": channels_last (N x D1..Dn x C) is not supported");
"QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported"); }
check_inputs(args); check_inputs(opd, args);
// Input: X // Input: X
...@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool ...@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool
const auto& zero_pt_x = args[2]; const auto& zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info); auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info);
// Output Y = globalaveragepool(X) // Output Y = pooling_op(X)
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = in_x->get_shape().lens();
std::vector<size_t> lengths(lens.begin() + 2, lens.end());
op.lengths = lengths;
op.padding = std::vector<size_t>(lens.size());
auto out_y = info.add_instruction(op, dquant_x);
const auto& scale_y = args[3]; auto out_y = add_pooling_op(opd, info, dquant_x);
const auto& zero_pt_y = args[4];
auto out_quant_y = bcast_qdq_instr("quantizelinear", out_y, scale_y, zero_pt_y, info); const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", out_y, in_scale_y, args[4], info));
return out_quant_y; // if no zero_pt: just broadcast the scale..
auto bcast_scale_y = bcast_scalar_instr(out_y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), out_y, bcast_scale_y));
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearSigmoid, QLinearLeakyRelu in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
com.microsoft.QLinearSigmoid
QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces
one output data (Tensor) where the function f(x) = quantize(Sigmoid(dequantize(x))), is applied to
the data tensor elementwise. Where the function Sigmoid(x) = 1 / (1 + exp(-x))
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
*****************************************************************************************************
com.microsoft.QLinearLeakyRelu
QLinearLeakyRelu takes quantized input data (Tensor), an argument alpha, and quantize parameter for
output, and produces one output data (Tensor) where the function f(x) = quantize(alpha *
dequantize(x)) for dequantize(x) < 0, f(x) = quantize(dequantize(x)) for dequantize(x) >= 0, is
applied to the data tensor elementwise.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
alpha : float
Coefficient of leakage.
******************************************************************************************************
Generic input layout of QLinear unary operators:
Inputs (4 - 5)
X : T
Input tensor
X_scale : tensor(float)
Input X's scale. It's a scalar, which means a per-tensor/layer quantization.
X_zero_point (optional) : T
Input X's zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Y_scale : tensor(float) Output Y's scale. It's a scalar, which means
a per-tensor/layer quantization.
Y_zero_point (optional) : T Output Y's zero point. Default value is
0 if it's not specified. It's a scalar, which means a per-tensor/layer quantization.
Outputs
Y : T
Output tensor
Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit tensors.
*/
struct parse_qlinearunary : op_parser<parse_qlinearunary>
{
std::vector<op_desc> operators() const
{
return {{"QLinearSigmoid", "sigmoid"}, {"QLinearLeakyRelu", "leaky_relu"}};
}
void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
if(args.size() < 4)
MIGRAPHX_THROW(opd.op_name + ": missing inputs");
const auto& in_x = args[0];
auto sh_x = in_x->get_shape();
auto type_x = sh_x.type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW(opd.op_name + ": unsupported input type");
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(opd, args);
// X
const auto& in_x = args[0];
const auto& in_scale_x = args[1];
const auto& in_zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, in_scale_x, in_zero_pt_x, info);
// Y = (op(dequantizelinear(x))
auto op = parser.load(opd.op_name, info);
auto y = info.add_instruction(op, dquant_x);
const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", y, in_scale_y, args[4], info));
// if no zero_pt: just broadcast the scale..
auto bcast_scale_sigm = bcast_scalar_instr(y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), y, bcast_scale_sigm));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd> ...@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const std::vector<instruction_ref>& args) const
{ {
std::string reduction = "none";
if(contains(info.attributes, "reduction")) if(contains(info.attributes, "reduction"))
{ {
if(info.attributes.at("reduction").s() == "add") reduction = info.attributes.at("reduction").s();
return info.add_instruction(migraphx::make_op("scatternd_add"), args); if(not contains({"none", "add", "mul", "min", "max"}, reduction))
if(info.attributes.at("reduction").s() == "mul") {
return info.add_instruction(migraphx::make_op("scatternd_mul"), args); MIGRAPHX_THROW("PARSE_SCATTERND: unsupported reduction mode " + reduction);
}
} }
return info.add_instruction(migraphx::make_op("scatternd_none"), args); return info.add_instruction(migraphx::make_op("scatternd_" + reduction), args);
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// generate unique output stream y, given input stream x;
//
// case unsorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// output(s):
// y: [2, 1, 3, 4] --- the unique output
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
//
// case sorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// output(s):
// y: [1, 2, 3, 4] --- the unique output
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
struct parse_unique : op_parser<parse_unique>
{
std::vector<op_desc> operators() const { return {{"Unique"}}; }
std::vector<instruction_ref> parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int64_t sorted = 1; // default = sorted.
if(contains(info.attributes, "sorted"))
sorted = parser.parse_value(info.attributes.at("sorted")).at<int>();
std::optional<int64_t> axis;
if(contains(info.attributes, "axis"))
{
auto n_dim = args[0]->get_shape().ndim();
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
axis = tune_axis(n_dim, *axis, opd.op_name);
}
migraphx::argument data_arg = args.back()->eval();
auto opr = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
: migraphx::make_op("unique", {{"sorted", sorted}});
auto u_opr = info.add_instruction(opr, args.at(0));
auto i_y = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), u_opr);
auto i_y_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), u_opr);
auto i_x_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 2}}), u_opr);
auto i_count = info.add_instruction(make_op("get_tuple_elem", {{"index", 3}}), u_opr);
return {i_y, i_y_idx, i_x_idx, i_count};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values)
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global") or starts_with(opd.onnx_name, "QLinearGlobal"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
}
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0)
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_pooling_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}), l1);
}
return l1;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -35,25 +35,14 @@ ...@@ -35,25 +35,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(module& m) const static void replace_with_reduce(module& m, instruction_ref ins)
{ {
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue;
auto lens = s.lens(); auto lens = s.lens();
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue;
std::vector<std::int64_t> axes(lens.size() - 2); std::vector<std::int64_t> axes(lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
// average pooling // average pooling
if(op.mode == op::pooling_mode::average) if(op.mode == op::pooling_mode::average)
{ {
...@@ -64,6 +53,131 @@ void rewrite_pooling::apply(module& m) const ...@@ -64,6 +53,131 @@ void rewrite_pooling::apply(module& m) const
{ {
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs()); m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
} }
}
static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
// Ignore N, C axes
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
if(not default_padding)
{
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// We need to pad both ends
dims[idx] += op.padding.at(idx) * 2;
}
}
std::vector<size_t> kernels = op.lengths;
std::vector<size_t> strides = op.stride;
std::vector<size_t> dilations = op.dilations;
std::vector<std::vector<int>> axis_indices;
axis_indices.resize(dims.size());
for(auto idx{0}; idx < dims.size(); ++idx)
{
// Only consider if iw fits into the window
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
stride += strides.at(idx))
{
for(size_t step{0}; step < kernels.at(idx); ++step)
{
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
}
}
}
auto elements = ins->inputs().front();
if(not default_padding)
{
// Pad supports asym, we need to provide both ends
std::vector<size_t> padding(2 * s.lens().size(), 0);
// Format will be e.g {N, C, P1, P2, N, C, P1, P2}
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// Ignore N, C axes
padding.at(2 + idx) = op.padding.at(idx);
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
}
// Default value needed for Max pooling
elements = m.insert_instruction(
ins,
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
elements);
}
for(auto idx{0}; idx < axis_indices.size(); ++idx)
{
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
elements = m.insert_instruction(
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
}
// Ignore padding
std::vector<size_t> new_padding(kernels.size(), 0);
// The kernel window elements are places next to each other. E.g. {x1, y1, x2, y2, ...}
// We need to skip them to not overlap
std::vector<size_t> new_strides(kernels);
// Ignore dilations
std::vector<size_t> new_dilations(kernels.size(), 1);
m.replace_instruction(ins,
make_op("pooling",
{{"mode", op.mode},
{"padding", new_padding},
{"stride", new_strides},
{"lengths", kernels},
{"dilations", new_dilations}}),
elements);
}
void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
bool same_kernel_as_shape = std::equal(
s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
bool default_strides =
std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
bool default_dilations =
std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
{
replace_with_reduce(m, ins);
}
else if(not default_dilations)
{
// Dilated AvgPool with padding is not supported
if(not default_padding and op.mode == op::pooling_mode::average)
{
continue;
}
auto size =
std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
// Can't handle too much size because of literal size
if(size > 100000)
{
continue;
}
replace_dilations_with_gather_pooling(m, ins);
}
} }
} }
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp> #include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp> #include <migraphx/dom_info.hpp>
...@@ -461,7 +461,7 @@ struct stream_info ...@@ -461,7 +461,7 @@ struct stream_info
std::back_inserter(index_to_ins), std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; }); [](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) { simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index]; auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0); assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first); auto& merge_second = concur_ins.at(merge_first);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/simplify_dyn_ops.hpp> #include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -65,8 +66,65 @@ struct find_static_2in_broadcasts ...@@ -65,8 +66,65 @@ struct find_static_2in_broadcasts
}; };
/** /**
* Simplify slice with variable `starts` and `ends` to the constant version if * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
* the `input_starts` and `input_ends` inputs are constant. * From:
* slice(data, constant_input); two attributes set
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/
struct find_const_2in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == op::slice::ends_axes)
{
// slice(data, starts)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::starts_axes)
{
// slice(data, ends)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
axes_vec = slice_op.axes;
}
else
{
// slice(data, axes)
inputs.at(1)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
ends_vec = slice_op.ends;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
};
/**
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
* From:
* slice(data, constant_input1, constant_input2); one attribute set
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/ */
struct find_const_3in_slice struct find_const_3in_slice
{ {
...@@ -81,27 +139,51 @@ struct find_const_3in_slice ...@@ -81,27 +139,51 @@ struct find_const_3in_slice
{ {
auto ins = mr.result; auto ins = mr.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(); auto slice_op = any_cast<op::slice>(ins->get_operator());
argument ends_arg = inputs.at(2)->eval(); auto set_attrs = slice_op.get_set_attributes();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec; std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec; std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); }); std::vector<int64_t> axes_vec;
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); }); if(set_attrs == op::slice::axes_only)
auto slice_val = ins->get_operator().to_value(); {
auto axes_vec = slice_val.at("axes").to_vector<int64_t>(); // slice(data, starts, ends)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_vec = slice_op.axes;
}
else if(set_attrs == op::slice::ends_only)
{
// slice(data, starts, axes)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
}
else
{
// slice(data, ends, axes)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
}
m.replace_instruction( m.replace_instruction(
ins, ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0)); inputs.at(0));
} }
}
}; };
/** /**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if * Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant. * From:
* slice(data, constant_starts, constant_ends, constant_axes)
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/ */
struct find_const_4in_slice struct find_const_4in_slice
{ {
...@@ -117,9 +199,9 @@ struct find_const_4in_slice ...@@ -117,9 +199,9 @@ struct find_const_4in_slice
{ {
auto ins = mr.result; auto ins = mr.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(); argument starts_arg = inputs.at(1)->eval(false);
argument ends_arg = inputs.at(2)->eval(); argument ends_arg = inputs.at(2)->eval(false);
argument axes_arg = inputs.at(3)->eval(); argument axes_arg = inputs.at(3)->eval(false);
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty()) if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{ {
std::vector<int64_t> starts_vec; std::vector<int64_t> starts_vec;
...@@ -179,6 +261,7 @@ struct find_static_dimensions_of ...@@ -179,6 +261,7 @@ struct find_static_dimensions_of
/** /**
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1 * Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes. * argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
* From: * From:
* x = allocate(constant_output_dims) -> reshape(data, x) * x = allocate(constant_output_dims) -> reshape(data, x)
* To: * To:
...@@ -207,14 +290,44 @@ struct find_const_alloc_reshapes ...@@ -207,14 +290,44 @@ struct find_const_alloc_reshapes
} }
}; };
/**
* Simplify allocate into fill operator that has constant output dimensions and constant value.
* The allocate into fill instructions is what is produced when parsing the ONNX
* ConstantOfShape operator. This replacement could be handled with propagate_constant, but
* would rather have the simplification happen earlier during compiling.
* This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
* From:
* x = allocate(constant_ouptut_dims) -> fill(constant_value, x)
* To:
* literal
*/
struct find_const_alloc_fill
{
auto matcher() const
{
return match::name("fill")(match::arg(0)(match::is_constant()),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto fill_ins = mr.result;
auto fill_arg = fill_ins->eval(false);
auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data());
m.replace_instruction(fill_ins, l);
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches(m, match::find_matches(m,
find_static_dimensions_of{}, find_static_dimensions_of{},
find_const_alloc_reshapes{}, find_const_alloc_reshapes{},
find_static_2in_broadcasts{}, find_static_2in_broadcasts{},
find_const_2in_slice{},
find_const_3in_slice{}, find_const_3in_slice{},
find_const_4in_slice{}); find_const_4in_slice{},
find_const_alloc_fill{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -67,8 +67,8 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) ...@@ -67,8 +67,8 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::float_type: return dt::f32; case st::float_type: return dt::f32;
case st::int32_type: return dt::s32; case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8; case st::int8_type: return dt::s8;
case st::uint8_type: case st::uint8_type: return dt::u8;
case st::fp8e4m3fnuz_type: return dt::u8; case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
default: MIGRAPHX_THROW("Unsupported data type"); default: MIGRAPHX_THROW("Unsupported data type");
} }
} }
......
...@@ -340,7 +340,6 @@ struct cpu_apply ...@@ -340,7 +340,6 @@ struct cpu_apply
{"reduce_min", "reduction_min"}, {"reduce_min", "reduction_min"},
{"reduce_sum", "reduction_sum"}, {"reduce_sum", "reduction_sum"},
}); });
extend_op("concat", "dnnl::concat"); extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder"); extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution"); extend_op("convolution", "dnnl::convolution");
...@@ -376,6 +375,12 @@ struct cpu_apply ...@@ -376,6 +375,12 @@ struct cpu_apply
// Apply these operators first so the inputs can be const folded // Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl)) for(auto it : iterator_for(*modl))
{ {
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pow") if(it->name() == "pow")
{ {
apply_pow(it); apply_pow(it);
...@@ -383,6 +388,12 @@ struct cpu_apply ...@@ -383,6 +388,12 @@ struct cpu_apply
} }
for(auto it : iterator_for(*modl)) for(auto it : iterator_for(*modl))
{ {
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pooling") if(it->name() == "pooling")
{ {
apply_pooling(it); apply_pooling(it);
......
...@@ -34,23 +34,32 @@ namespace migraphx { ...@@ -34,23 +34,32 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling> struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_v2_forward, op::pooling>
{ {
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; } std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::pooling_v2_forward::desc
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg; : dnnl::algorithm::pooling_avg;
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
// Note: It is not documented, but the default dilation seems to be 0 instead of 1.
// We need to offset dilations with -1.
std::vector<size_t> dilations;
std::transform(op.dilations.cbegin(),
op.dilations.cend(),
std::back_inserter(dilations),
[](size_t d) { return d - 1; });
return {dnnl::prop_kind::forward_inference, return {dnnl::prop_kind::forward_inference,
algo, algo,
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride), to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths), to_dnnl_dims(op.lengths),
to_dnnl_dims(dilations),
to_dnnl_dims(padding_l), to_dnnl_dims(padding_l),
to_dnnl_dims(padding_r)}; to_dnnl_dims(padding_r)};
} }
......
...@@ -126,7 +126,6 @@ add_library(migraphx_gpu ...@@ -126,7 +126,6 @@ add_library(migraphx_gpu
fuse_ck.cpp fuse_ck.cpp
fuse_mlir.cpp fuse_mlir.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp
gemm_impl.cpp gemm_impl.cpp
hip.cpp hip.cpp
kernel.cpp kernel.cpp
...@@ -140,7 +139,6 @@ add_library(migraphx_gpu ...@@ -140,7 +139,6 @@ add_library(migraphx_gpu
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
prefuse_ops.cpp prefuse_ops.cpp
pad.cpp
perfdb.cpp perfdb.cpp
pooling.cpp pooling.cpp
reverse.cpp reverse.cpp
...@@ -168,12 +166,10 @@ endfunction() ...@@ -168,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_ register_migraphx_gpu_ops(hip_
argmax argmax
argmin argmin
gather
logsoftmax logsoftmax
loop loop
multinomial multinomial
nonzero nonzero
pad
prefix_scan_sum prefix_scan_sum
reverse reverse
scatter scatter
......
...@@ -194,7 +194,7 @@ struct hiprtc_program ...@@ -194,7 +194,7 @@ struct hiprtc_program
}; };
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs, std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
std::string params, const std::string& params,
const std::string& arch) const std::string& arch)
{ {
hiprtc_program prog(std::move(srcs)); hiprtc_program prog(std::move(srcs));
...@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags) ...@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags)
} }
} }
std::vector<std::vector<char>> std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) const std::string& params,
const std::string& arch)
{ {
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()}; std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
...@@ -281,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -281,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(fs::exists(out)) if(fs::exists(out))
return {read_buffer(out.string())}; return {read_buffer(out.string())};
} }
return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch); return compile_hip_src_with_hiprtc(std::move(hsrcs), params, arch);
} }
#else // MIGRAPHX_USE_HIPRTC #else // MIGRAPHX_USE_HIPRTC
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
std::string, // NOLINT const std::string&, // NOLINT
const std::string&) const std::string&)
{ {
MIGRAPHX_THROW("Not using hiprtc"); MIGRAPHX_THROW("Not using hiprtc");
...@@ -316,29 +317,15 @@ src_compiler assemble(src_compiler compiler) ...@@ -316,29 +317,15 @@ src_compiler assemble(src_compiler compiler)
return compiler; return compiler;
} }
std::vector<std::vector<char>> std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) const std::string& params,
const std::string& arch)
{ {
assert(not srcs.empty()); assert(not srcs.empty());
if(not is_hip_clang_compiler()) if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER); MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
if(params.find("-std=") == std::string::npos)
params += " --std=c++17";
params += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g";
params += " -c";
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_HIP_COMPILER_FLAGS;
src_compiler compiler; src_compiler compiler;
compiler.flags = params; compiler.flags = params;
compiler.compiler = MIGRAPHX_HIP_COMPILER; compiler.compiler = MIGRAPHX_HIP_COMPILER;
...@@ -346,6 +333,23 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -346,6 +333,23 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(has_compiler_launcher()) if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER; compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
#endif #endif
if(params.find("-std=") == std::string::npos)
compiler.flags += " --std=c++17";
compiler.flags += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
compiler.flags += " -g";
compiler.flags += " -c";
compiler.flags += " --offload-arch=" + arch;
compiler.flags += " --cuda-device-only";
compiler.flags += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
if(enabled(MIGRAPHX_GPU_DEBUG{}))
compiler.flags += " -DMIGRAPHX_DEBUG";
compiler.flags += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
compiler.flags += MIGRAPHX_HIP_COMPILER_FLAGS;
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{ {
for(const auto& src : srcs) for(const auto& src : srcs)
......
...@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0"; options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror"; options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name()); auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1) if(cos.size() != 1)
MIGRAPHX_THROW("No code object"); MIGRAPHX_THROW("No code object");
return code_object_op{value::binary{cos.front()}, return code_object_op{value::binary{cos.front()},
......
...@@ -43,24 +43,32 @@ template <index_int N, ...@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output) __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{ {
using type = decltype(input(deduce_for_stride(fs))); using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N]; MIGRAPHX_DEVICE_SHARED type buffer[2][N];
type x = init; type x = init;
fs([&](auto i) { fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0) if(idx.local == 0)
buffer[idx.local] = op(input(i), x); buffer[iout][idx.local] = op(input(i), x);
else else
buffer[idx.local] = input(i); buffer[iout][idx.local] = input(i);
__syncthreads(); __syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2) for(index_int s = 1; s < idx.nlocal(); s *= 2)
{ {
if(idx.local + s < idx.nlocal()) iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{ {
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]); buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
} }
__syncthreads(); __syncthreads();
} }
x = buffer[idx.nlocal() - 1]; x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[idx.local]); output(i, buffer[iout][idx.local]);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment