gaoqiong / MIGraphX · Commits · 6f768035

Commit 6f768035 · authored Dec 06, 2023 by Umang Yadav
Merge branch 'rocblas_mlir_fp8' into miopen_fp8
Parents: da7717ce, b2542239
Changes: 208

Showing 20 changed files with 1191 additions and 375 deletions (+1191, −375)
  src/onnx/include/migraphx/onnx/pooling.hpp                    +12   −17
  src/onnx/onnx.proto                                           +193  −61
  src/onnx/onnx_parser.cpp                                      +23   −0
  src/onnx/parse_multinomial.cpp                                +3    −3
  src/onnx/parse_pooling.cpp                                    +11   −207
  src/onnx/parse_qlinearpooling.cpp                             +115  −0
  src/onnx/parse_qlinearunary.cpp                               +151  −0
  src/onnx/parse_scatternd.cpp                                  +7    −5
  src/onnx/parse_unique.cpp                                     +92   −0
  src/onnx/pooling.cpp                                          +247  −0
  src/rewrite_pooling.cpp                                       +131  −17
  src/schedule.cpp                                              +2    −2
  src/simplify_dyn_ops.cpp                                      +134  −21
  src/targets/cpu/dnnl.cpp                                      +2    −2
  src/targets/cpu/lowering.cpp                                  +12   −1
  src/targets/cpu/pooling.cpp                                   +13   −4
  src/targets/gpu/CMakeLists.txt                                +0    −4
  src/targets/gpu/compile_hip.cpp                               +27   −23
  src/targets/gpu/compile_hip_code_object.cpp                   +1    −1
  src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp   +15   −7
src/include/migraphx/eliminate_fp8.hpp → src/onnx/include/migraphx/onnx/pooling.hpp

@@ -21,31 +21,26 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
-#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP
+#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
+#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
 #include <migraphx/config.hpp>
-#include <migraphx/shape.hpp>
-#include <set>
-#include <string>
+#include <migraphx/onnx/onnx_parser.hpp>
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/instruction.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
 
-struct module;
+value handle_pooling_values(const op_desc& opd,
+                            onnx_parser::node_info info,
+                            const shape& in_shape,
+                            value values);
 
-/**
-This will insert convert operators for the operators that are not implemented for FP8 dtypes
-*/
-struct MIGRAPHX_EXPORT eliminate_fp8
-{
-    // TODO: Add all device ops as a later PR and add tests for those.
-    std::set<std::string> op_names;
-    shape::type_t target_type = migraphx::shape::float_type;
-    std::string name() const { return "eliminate_fp8"; }
-    void apply(module& m) const;
-};
+instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
 
+} // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
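For orientation (not part of this commit): the two helpers declared above centralize the ONNX attribute handling that the individual pooling parsers later in this commit delegate to. A self-contained sketch of the kind of attribute mapping handle_pooling_values performs; the names pooling_values and map_onnx_pooling_attributes are hypothetical and only illustrate the mapping kernel_shape -> lengths, strides -> stride, ceil_mode -> ceil_mode, p -> lp_order.

// Illustrative only; not MIGraphX code.
#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct pooling_values
{
    std::vector<std::int64_t> lengths;
    std::vector<std::int64_t> stride;
    bool ceil_mode        = false;
    std::int64_t lp_order = 2;
};

pooling_values
map_onnx_pooling_attributes(const std::map<std::string, std::vector<std::int64_t>>& attrs)
{
    pooling_values v;
    if(attrs.count("kernel_shape"))
        v.lengths = attrs.at("kernel_shape"); // pooling window sizes per spatial dim
    if(attrs.count("strides"))
        v.stride = attrs.at("strides");       // window strides per spatial dim
    if(attrs.count("ceil_mode"))
        v.ceil_mode = attrs.at("ceil_mode").front() != 0;
    if(attrs.count("p"))
        v.lp_order = attrs.at("p").front();   // order of the Lp pooling norm
    return v;
}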
src/onnx/onnx.proto

@@ -3,8 +3,8 @@
 //
-// Licensed under the MIT license.
+// Copyright (c) ONNX Project Contributors.
+// SPDX-License-Identifier: Apache-2.0
 
 syntax = "proto2";

@@ -27,13 +27,6 @@ package onnx_for_migraphx;
 // Notes
 //
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
 // Protobuf compatibility
 //
 // To simplify framework compatibility, ONNX is defined using the subset of protobuf

@@ -92,15 +85,28 @@ enum Version {
   // - Add sparse initializers
   IR_VERSION_2019_9_19 = 0x0000000000000006;
 
-  // IR VERSION 7 published on <TBD>
+  // IR VERSION 7 published on May 8, 2020
+  // - Add support to allow function body graph to rely on multiple external opreator sets.
   // - Add a list to promote inference graph's initializers to global and
   //   mutable variables. Global variables are visible in all graphs of the
   //   stored models.
   // - Add message TrainingInfoProto to store initialization
   //   method and training algorithm. The execution of TrainingInfoProto
   //   can modify the values of mutable variables.
-  // - Make inference graph callable from TrainingInfoProto via GraphCall operator.
-  IR_VERSION = 0x0000000000000007;
+  // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+  IR_VERSION_2020_5_8 = 0x0000000000000007;
+
+  // IR VERSION 8 published on July 30, 2021
+  // Introduce TypeProto.SparseTensor
+  // Introduce TypeProto.Optional
+  // Added a list of FunctionProtos local to the model
+  // Deprecated since_version and operator status from FunctionProto
+  IR_VERSION_2021_7_30 = 0x0000000000000008;
+
+  // IR VERSION 9 published on TBD
+  // Added AttributeProto to FunctionProto so that default attribute values can be set.
+  // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+  IR_VERSION = 0x0000000000000009;
 }
 
 // Attributes

@@ -121,6 +127,7 @@ message AttributeProto {
     TENSOR = 4;
     GRAPH = 5;
     SPARSE_TENSOR = 11;
+    TYPE_PROTO = 13;
 
     FLOATS = 6;
     INTS = 7;

@@ -128,6 +135,7 @@ message AttributeProto {
     TENSORS = 9;
     GRAPHS = 10;
     SPARSE_TENSORS = 12;
+    TYPE_PROTOS = 14;
   }
 
   // The name field MUST be present for this version of the IR.

@@ -159,6 +167,7 @@ message AttributeProto {
   optional SparseTensorProto sparse_tensor = 22;  // sparse tensor value
   // Do not use field below, it's deprecated.
   // optional ValueProto v = 12;                  // value - subsumes everything but graph
+  optional TypeProto tp = 14;                     // type proto
 
   repeated float floats = 7;                      // list of floats
   repeated int64 ints = 8;                        // list of ints

@@ -166,6 +175,7 @@ message AttributeProto {
   repeated TensorProto tensors = 10;              // list of tensors
   repeated GraphProto graphs = 11;                // list of graph
   repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
+  repeated TypeProto type_protos = 15;            // list of type protos
 }
 
 // Defines information on value, including the name, the type, and

@@ -211,7 +221,7 @@ message NodeProto {
 // TrainingInfoProto stores information for training a model.
 // In particular, this defines two functionalities: an initialization-step
 // and a training-algorithm-step. Initialization resets the model
-// back to its original state as if no training has been consumed.
+// back to its original state as if no training has been performed.
 // Training algorithm improves the model based on input data.
 //
 // The semantics of the initialization-step is that the initializers

@@ -224,8 +234,8 @@ message NodeProto {
 // training algorithm's step. After the execution of a
 // TrainingInfoProto.algorithm, the initializers specified by "update_binding"
 // may be immediately updated. If the targeted training algorithm contains
-// consecutive update stages (such as block coordinate descent methods),
-// the user needs to create a TrainingInfoProto for each stage.
+// consecutive update steps (such as block coordinate descent methods),
+// the user needs to create a TrainingInfoProto for each step.
 message TrainingInfoProto {
   // This field describes a graph to compute the initial tensors
   // upon starting the training process. Initialization graph has no input

@@ -239,20 +249,38 @@ message TrainingInfoProto {
   // iteration to zero.
   //
   // By default, this field is an empty graph and its evaluation does not
-  // produce any output.
+  // produce any output. Thus, no initializer would be changed by default.
   optional GraphProto initialization = 1;
 
   // This field represents a training algorithm step. Given required inputs,
   // it computes outputs to update initializers in its own or inference graph's
-  // initializer lists. In general, this graph contains loss node, gradient node,
-  // optimizer node, increment of iteration count, and some calls to the inference
-  // graph.
+  // initializer lists. In general, this field contains loss node, gradient node,
+  // optimizer node, increment of iteration count.
   //
-  // The field algorithm.node is the only place the user can use GraphCall
-  // operator. The only callable graph is the one stored in ModelProto.graph.
+  // An execution of the training algorithm step is performed by executing the
+  // graph obtained by combining the inference graph (namely "ModelProto.graph")
+  // and the "algorithm" graph. That is, the actual the actual
+  // input/initializer/output/node/value_info/sparse_initializer list of
+  // the training graph is the concatenation of
+  // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+  // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+  // in that order. This combined graph must satisfy the normal ONNX conditions.
+  // Now, let's provide a visualization of graph combination for clarity.
+  // Let the inference graph (i.e., "ModelProto.graph") be
+  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+  // and the "algorithm" graph be
+  //    tensor_d -> Add -> tensor_e
+  // The combination process results
+  //    tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+  //
+  // Notice that an input of a node in the "algorithm" graph may reference the
+  // output of a node in the inference graph (but not the other way round). Also, inference
+  // node cannot reference inputs of "algorithm". With these restrictions, inference graph
+  // can always be run independently without training information.
   //
   // By default, this field is an empty graph and its evaluation does not
-  // produce any output.
+  // produce any output. Evaluating the default training step never
+  // update any initializers.
   optional GraphProto algorithm = 2;
 
   // This field specifies the bindings from the outputs of "initialization" to

@@ -284,23 +312,16 @@ message TrainingInfoProto {
   // be multiple key-value pairs in "update_binding".
   //
   // The initializers appears as keys in "update_binding" are considered
-  // mutable and globally-visible variables. This implies some behaviors
+  // mutable variables. This implies some behaviors
   // as described below.
   //
-  // 1. We have only unique keys in all "update_binding"s so that two global
+  // 1. We have only unique keys in all "update_binding"s so that two
   //    variables may not have the same name. This ensures that one
-  //    global variable is assigned up to once.
+  //    variable is assigned up to once.
   // 2. The keys must appear in names of "ModelProto.graph.initializer" or
   //    "TrainingInfoProto.algorithm.initializer".
-  // 3. The values must be output names of "algorithm".
-  // 4. If an optional input of a graph is omitted when using GraphCall, the
-  //    global variable with the same name may be used.
-  // 5. When using GraphCall, the users always can pass values to optional
-  //    inputs of the called graph even if the associated initializers appears
-  //    as keys in "update_binding"s.
-  // 6. The graphs in TrainingInfoProto's can use global variables as
-  //    their operator inputs.
-  // 7. Mutable variables are initialized to the value specified by the
+  // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+  // 4. Mutable variables are initialized to the value specified by the
   //    corresponding initializer, and then potentially updated by
   //    "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
   //

@@ -375,13 +396,31 @@ message ModelProto {
   //
   // If this field is empty, the training behavior of the model is undefined.
   repeated TrainingInfoProto training_info = 20;
+
+  // A list of function protos local to the model.
+  //
+  // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+  // In case of any conflicts the behavior (whether the model local functions are given higher priority,
+  // or standard opserator sets are given higher priotity or this is treated as error) is defined by
+  // the runtimes.
+  //
+  // The operator sets imported by FunctionProto should be compatible with the ones
+  // imported by ModelProto and other model local FunctionProtos.
+  // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
+  // or by 2 FunctionProtos then versions for the operator set may be different but,
+  // the operator schema returned for op_type, domain, version combination
+  // for both the versions should be same for every node in the function body.
+  //
+  // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
+  // is not allowed.
+  repeated FunctionProto functions = 25;
 };
 
 // StringStringEntryProto follows the pattern for cross-proto-version maps.
 // See https://developers.google.com/protocol-buffers/docs/proto3#maps
 message StringStringEntryProto {
   optional string key = 1;
   optional string value = 2;
 };
 
 message TensorAnnotation {

@@ -409,8 +448,9 @@ message GraphProto {
   optional string name = 2;   // namespace Graph
 
   // A list of named tensor values, used to specify constant inputs of the graph.
-  // Each TensorProto entry must have a distinct name (within the list) that
-  // MAY also appear in the input list.
+  // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
+  // The name MUST be unique across both initializer and sparse_initializer,
+  // but the name MAY also appear in the input list.
   repeated TensorProto initializer = 5;
 
   // Initializers (see above) stored in sparse format.

@@ -433,13 +473,8 @@ message GraphProto {
   // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
   repeated TensorAnnotation quantization_annotation = 14;
 
-  // DO NOT USE the following fields, they were deprecated from earlier versions.
-  // repeated string input = 3;
-  // repeated string output = 4;
-  // optional int64 ir_version = 6;
-  // optional int64 producer_version = 7;
-  // optional string producer_tag = 8;
-  // optional string domain = 9;
+  reserved 3, 4, 6 to 9;
+  reserved "ir_version", "producer_version", "producer_tag", "domain";
 }
 
 // Tensors

@@ -474,6 +509,17 @@ message TensorProto {
     // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
     BFLOAT16 = 16;
 
+    // Non-IEEE floating-point format based on papers
+    // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+    // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+    // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+    // The computation usually happens inside a block quantize / dequantize
+    // fused by the runtime.
+    FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf
+    FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+    FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients
+    FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+
     // Future extensions go here.
   }

@@ -507,11 +553,11 @@ message TensorProto {
   // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
   repeated float float_data = 4 [packed = true];
 
-  // For int32, uint8, int8, uint16, int16, bool, and float16 values
-  // float16 values must be bit-wise converted to an uint16_t prior
+  // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
+  // float16 and float8 values must be bit-wise converted to an uint16_t prior
   // to writing to the buffer.
   // When this field is present, the data_type field MUST be
-  // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
+  // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
   repeated int32 int32_data = 5 [packed = true];
 
   // For strings.

@@ -589,6 +635,8 @@ message TensorProto {
 message SparseTensorProto {
   // The sequence of non-default values are encoded as a tensor of shape [NNZ].
   // The default-value is zero for numeric tensors, and empty-string for string tensors.
+  // values must have a non-empty name present which serves as a name for SparseTensorProto
+  // when used in sparse_initializer list.
   optional TensorProto values = 1;
 
   // The indices of the non-default values, which may be stored in one of two formats.

@@ -619,7 +667,7 @@ message TensorShapeProto {
     // Standard denotation can optionally be used to denote tensor
     // dimensions with standard semantic descriptions to ensure
     // that operations are applied to the correct axis of a tensor.
-    // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
+    // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
     // for pre-defined dimension denotations.
     optional string denotation = 3;
   };

@@ -656,6 +704,23 @@ message TypeProto {
     optional TypeProto value_type = 2;
   };
 
+  // wrapper for Tensor, Sequence, or Map
+  message Optional {
+    // The type and optional shape of the element wrapped.
+    // This field MUST be present for this version of the IR.
+    // Possible values correspond to OptionalProto.DataType enum
+    optional TypeProto elem_type = 1;
+  };
+
+  message SparseTensor {
+    // This field MUST NOT have the value of UNDEFINED
+    // This field MUST have a valid TensorProto.DataType value
+    // This field MUST be present for this version of the IR.
+    optional int32 elem_type = 1;
+    optional TensorShapeProto shape = 2;
+  }
+
   oneof value {
     // The type of a tensor.

@@ -672,11 +737,18 @@ message TypeProto {
     // The type of a map.
     Map map_type = 5;
 
+    // The type of an optional.
+    Optional optional_type = 9;
+
+    // Type of the sparse tensor
+    SparseTensor sparse_tensor_type = 8;
   }
 
   // An optional denotation can be used to denote the whole
   // type with a standard semantic description as to what is
-  // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
+  // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
   // for pre-defined type denotations.
   optional string denotation = 6;
 }

@@ -696,7 +768,67 @@ message OperatorSetIdProto {
   optional int64 version = 2;
 }
 
+// Operator/function status.
+enum OperatorStatus {
+  EXPERIMENTAL = 0;
+  STABLE = 1;
+}
+
+message FunctionProto {
+  // The name of the function, similar usage of op_type in OperatorProto.
+  // Combined with FunctionProto.domain, this forms the unique identity of
+  // the FunctionProto.
+  optional string name = 1;
+
+  // Deprecated since IR Version 8
+  // optional int64 since_version = 2;
+  reserved 2;
+  reserved "since_version";
+
+  // Deprecated since IR Version 8
+  // optional OperatorStatus status = 3;
+  reserved 3;
+  reserved "status";
+
+  // The inputs and outputs of the function.
+  repeated string input = 4;
+  repeated string output = 5;
+
+  // The attribute parameters of the function.
+  // It is for function parameters without default values.
+  repeated string attribute = 6;
+
+  // The attribute protos of the function.
+  // It is for function attributes with default values.
+  // A function attribute shall be represented either as
+  // a string attribute or an AttributeProto, not both.
+  repeated AttributeProto attribute_proto = 11;
+
+  // The nodes in the function.
+  repeated NodeProto node = 7;
+  // A human-readable documentation for this function. Markdown is allowed.
+  optional string doc_string = 8;
+
+  // The OperatorSets this function body (graph) relies on.
+  //
+  // All nodes in the function body (graph) will bind against the operator
+  // with the same-domain/same-op_type operator with the HIGHEST version
+  // in the referenced operator sets. This means at most one version can be relied
+  // for one domain.
+  //
+  // The operator sets imported by FunctionProto should be compatible with the ones
+  // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
+  // and ModelProto then versions for the operator set may be different but,
+  // the operator schema returned for op_type, domain, version combination
+  // for both the versions should be same.
+  repeated OperatorSetIdProto opset_import = 9;
+
+  // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+  // the FunctionProto.
+  optional string domain = 10;
+}
+
 // For using protobuf-lite
 option optimize_for = LITE_RUNTIME;
\ No newline at end of file
src/onnx/onnx_parser.cpp

@@ -34,7 +34,9 @@
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/filesystem.hpp>
 #include <migraphx/op/unknown.hpp>
+#include <migraphx/float8.hpp>
 #include <migraphx/env.hpp>
+#include <onnx.pb.h>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
     case onnx::AttributeProto::TENSORS:
     case onnx::AttributeProto::SPARSE_TENSOR:
     case onnx::AttributeProto::SPARSE_TENSORS:
+    case onnx::AttributeProto::TYPE_PROTOS:
+    case onnx::AttributeProto::TYPE_PROTO:
     case onnx::AttributeProto::GRAPHS: return {};
     }
     MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));

@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     case onnx::TensorProto::DOUBLE:
         return create_literal(shape::double_type, dims, t.double_data());
     case onnx::TensorProto::FLOAT:
        return create_literal(shape::float_type, dims, t.float_data());
+    case onnx::TensorProto::FLOAT8E4M3FNUZ: {
+        std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
+        std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
+        std::transform(data_int32.begin(),
+                       data_int32.end(),
+                       std::back_inserter(data_fp8),
+                       [](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
+        return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
+    }
+    case onnx::TensorProto::FLOAT8E5M2FNUZ:
+    case onnx::TensorProto::FLOAT8E5M2:
+    case onnx::TensorProto::FLOAT8E4M3FN:
     case onnx::TensorProto::UNDEFINED:
     case onnx::TensorProto::STRING:
     case onnx::TensorProto::COMPLEX64:

@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
     case 11: return shape::double_type;
     case 12: return shape::uint32_type;
     case 13: return shape::uint64_type;
+    case 18: return shape::fp8e4m3fnuz_type;
+    case 14:
+    case 15:
+    case 16:
+    case 17:
+    case 19:
+    case 20:
     default: {
         MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
     }
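For illustration (not part of the commit): per the onnx.proto comment above, FLOAT8 initializers carry one element per int32_data entry, and the new FLOAT8E4M3FNUZ case in parse_tensor maps every entry through the fp8 constructor with std::transform. A standalone sketch of that element-wise pattern; to_fp8e4m3fnuz is a stand-in placeholder, not the real migraphx::fp8::fp8e4m3fnuz conversion.

// Illustrative only; the real conversion lives in migraphx/float8.hpp.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

int main()
{
    std::vector<std::int32_t> int32_data = {0, 1, 2, 3}; // raw protobuf payload, one fp8 value per entry
    std::vector<std::uint8_t> data_fp8;
    auto to_fp8e4m3fnuz = [](std::int32_t raw) {
        // placeholder: a real implementation rounds `raw` to the nearest
        // representable e4m3fnuz value; here we only narrow to 8 bits
        return static_cast<std::uint8_t>(raw);
    };
    std::transform(int32_data.begin(),
                   int32_data.end(),
                   std::back_inserter(data_fp8),
                   to_fp8e4m3fnuz);
    return static_cast<int>(data_fp8.size());
}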
src/onnx/parse_multinomial.cpp

@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
         // use literal. The array populated by random_uniform may have any shape, as long its
         // number of elements is batch_size * sample_size .
         size_t batch_size = s0.lens().front();
         auto rand_dummy   = info.add_literal(
             migraphx::literal{
-                migraphx::shape::float_type, {batch_size * sample_size}});
+                migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
+                std::vector<float>(batch_size * sample_size)});
         randoms =
             info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
     }
src/onnx/parse_pooling.cpp

@@ -22,14 +22,8 @@
  * THE SOFTWARE.
  */
 #include <migraphx/onnx/op_parser.hpp>
-#include <migraphx/onnx/checks.hpp>
+#include <migraphx/onnx/pooling.hpp>
-#include <migraphx/onnx/padding.hpp>
-#include <migraphx/op/pad.hpp>
-#include <migraphx/op/pooling.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/ranges.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/make_op.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -39,68 +33,14 @@ struct parse_pooling : op_parser<parse_pooling>
 {
     std::vector<op_desc> operators() const
     {
-        return {{"AveragePool", "average"},
+        return {
+            {"AveragePool", "average"},
             {"GlobalAveragePool", "average"},
             {"GlobalMaxPool", "max"},
             {"MaxPool", "max"},
             {"LpPool", "lpnorm"},
-                {"GlobalLpPool", "lpnorm"}};
+            {"GlobalLpPool", "lpnorm"},
+        };
     }
 
-    value handle_values(const op_desc& opd,
-                        onnx_parser::node_info info,
-                        const shape& in_shape,
-                        value values) const
-    {
-        auto kdims = in_shape.ndim() - 2;
-        if(starts_with(opd.onnx_name, "Global"))
-        {
-            // if spatial dimensions are dynamic use dyn_global flag
-            if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
-                                                  in_shape.dyn_dims().cend(),
-                                                  [](auto dd) { return not dd.is_fixed(); }))
-            {
-                values["dyn_global"] = true;
-                values["lengths"]    = std::vector<size_t>();
-            }
-            else
-            {
-                // works with static and fixed dynamic shape
-                auto m_lens       = in_shape.max_lens();
-                values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
-            }
-        }
-        if(contains(info.attributes, "ceil_mode"))
-        {
-            values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
-        }
-        if(contains(info.attributes, "strides"))
-        {
-            values["stride"].clear();
-            copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
-            check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
-        }
-        if(contains(info.attributes, "kernel_shape"))
-        {
-            values["lengths"].clear();
-            copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
-            check_attr_sizes(kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
-        }
-        // lp_order attribute
-        if(contains(info.attributes, "p"))
-        {
-            values["lp_order"] = info.attributes.at("p").i();
-        }
-        // ensure pads available only when auto_pad is "NOT_SET"
-        check_padding_mode(info, "POOLING");
-        return values;
-    }
-
     instruction_ref parse(const op_desc& opd,

@@ -108,144 +48,8 @@ struct parse_pooling : op_parser<parse_pooling>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
+        return add_pooling_op(opd, std::move(info), args[0]);
-        std::string mode = opd.op_name;
-        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
-            {"max", op::pooling_mode::max},
-            {"average", op::pooling_mode::average},
-            {"lpnorm", op::pooling_mode::lpnorm}};
-        if(not contains(mode_map, mode))
-        {
-            MIGRAPHX_THROW(
-                "PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
-        }
-        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
-        value values = op.to_value();
-        auto l0       = args[0];
-        auto in_shape = l0->get_shape();
-        assert(in_shape.ndim() > 2);
-        auto kdims = in_shape.ndim() - 2;
-
-        values = handle_values(opd, info, in_shape, values);
-
-        // count include padding, if count include pad is 1, we always use
-        // explicit pad
-        int count_include_pad = 0;
-        if(contains(info.attributes, "count_include_pad"))
-        {
-            if(in_shape.dynamic())
-            {
-                MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
-                               "dynamic input shape");
-            }
-            count_include_pad = info.attributes.at("count_include_pad").i();
-        }
-
-        std::vector<int64_t> paddings;
-        float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
-
-        if(contains(info.attributes, "pads"))
-        {
-            values["padding"].clear();
-            copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
-            check_attr_sizes(
-                kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
-        }
-
-        if(paddings.size() != 2 * kdims)
-        {
-            paddings.resize(kdims * 2);
-            std::fill_n(paddings.begin(), 2 * kdims, 0);
-        }
-
-        if(values["padding"].size() != kdims)
-        {
-            values["padding"].resize(kdims);
-            std::fill_n(values["padding"].begin(), kdims, 0);
-        }
-
-        if(values["stride"].size() != kdims)
-        {
-            values["stride"].resize(kdims);
-            std::fill_n(values["stride"].begin(), kdims, 1);
-        }
-        // used to calculate the supposed output shape
-        std::vector<int64_t> orig_padding = paddings;
-
-        // TODO: add parsing for dilations
-        if(contains(info.attributes, "auto_pad") and
-           to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
-        {
-            auto auto_pad = to_upper(info.attributes["auto_pad"].s());
-            // don't use the given padding sizes, if any
-            // values["padding"].clear();
-            if(in_shape.dynamic())
-            {
-                // set padding_mode to trigger auto padding at runtime
-                bool is_same_upper     = (auto_pad.find("SAME_UPPER") != std::string::npos);
-                values["padding_mode"] = is_same_upper
-                                             ? to_value(op::padding_mode_t::same_upper)
-                                             : to_value(op::padding_mode_t::same_lower);
-            }
-            else
-            {
-                // Calculate auto padding
-                // dilations (argument 4) not supported; default to all 1's
-                cal_auto_padding_size(info,
-                                      values,
-                                      values["lengths"].to_vector<std::size_t>(),
-                                      std::vector<size_t>(in_shape.ndim() - 2, 1),
-                                      in_shape.lens(),
-                                      paddings);
-                values["padding"] = paddings;
-                // default padding_mode indicates that padding sizes are not calculated dynamically
-                values["padding_mode"] = migraphx::op::padding_mode_t::default_;
-            }
-        }
-
-        std::vector<int64_t> slice_start;
-        std::vector<int64_t> slice_end;
-        tune_padding_size(values, paddings, count_include_pad, slice_start);
-
-        if(not slice_start.empty())
-        {
-            if(in_shape.dynamic())
-            {
-                MIGRAPHX_THROW(
-                    "PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
-            }
-            // calculate expected output shape
-            orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
-            orig_padding.insert(orig_padding.begin(), 2, 0);
-            op::pad pad{orig_padding, 0.0f};
-            shape padded_shape = pad.compute_shape({l0->get_shape()});
-            // make an op just to get its output shape
-            auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
-            // compute slice_end information
-            slice_end.resize(slice_start.size());
-            std::transform(out_lens.begin() + 2,
-                           out_lens.end(),
-                           slice_start.begin(),
-                           slice_end.begin(),
-                           [](auto i, auto j) { return i + j; });
-        }
-        values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
-        check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
-
-        op.from_value(values);
-        auto l1 = info.add_instruction(op, l0);
-
-        if(not slice_start.empty())
-        {
-            std::vector<int64_t> axes(kdims);
-            std::iota(axes.begin(), axes.end(), 2);
-            l1 = info.add_instruction(
-                make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
-                l1);
-        }
-        return l1;
     }
 };
 
 } // namespace onnx
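Side note (not part of the diff): the removed parse() body above computes a "supposed output shape" before tuning asymmetric padding. The extent computation it relies on is the usual ONNX pooling formula; a minimal standalone sketch, ignoring dilations:

// Illustrative only: output extent of one spatial dimension of a pooling op.
#include <cmath>
#include <cstdint>

std::int64_t pooled_extent(std::int64_t in, std::int64_t pad_begin, std::int64_t pad_end,
                           std::int64_t kernel, std::int64_t stride, bool ceil_mode)
{
    // floor (default) or ceil (ceil_mode=1) of (in + pads - kernel) / stride, plus one
    double d = static_cast<double>(in + pad_begin + pad_end - kernel) / stride;
    return static_cast<std::int64_t>(ceil_mode ? std::ceil(d) : std::floor(d)) + 1;
}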
src/onnx/parse_qlinearglavgpool.cpp → src/onnx/parse_qlinearpooling.cpp

@@ -23,6 +23,7 @@
  */
 #include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/pooling.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/op/pooling.hpp>
 #include <migraphx/make_op.hpp>

@@ -36,90 +37,56 @@ namespace onnx {
 /*
  *********************************************************************************
- * Reference: see QLinearGlobalAveragePool in                                    *
+ * Reference: see QLinearAveragePool and QLinearGlobalAveragePool in             *
  * github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md           *
  *********************************************************************************
+ */
-
-QLinearGlobalAveragePool consumes an input tensor X and applies
-Average pooling across the values in the same channel. This is
-equivalent to AveragePool with kernel size equal to the spatial
-dimension of input tensor. Input is of type uint8_t or int8_t.
-
-Version
-This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
-
-Attributes
-channels_last : int
-
-Inputs
-X : T
-Input data tensor from the previous operator; According to channels_last, dimensions for image case
-are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
-H and W are the height and the width of the data. For non image case, the dimensions are in the form
-of (N x C x D1 x D2 ... Dn), or (N x D1 X D2 ... Dn x C) where N is the batch size.
-x_scale : tensor(float)
-Scale of quantized input 'X'. It must be a scalar.
-x_zero_point : T
-Zero point tensor for input 'X'. It must be a scalar.
-y_scale : tensor(float)
-Scale of quantized output 'Y'. It must be a scalar.
-y_zero_point : T
-Zero point tensor for output 'Y'. It must be a scalar.
-
-Outputs
-Y : T
-Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
-input. with the N and C value keep it value, while the other dimensions are all 1. Type Constraints
-T : tensor(uint8), tensor(int8)
-Constrain input and output types to signed/unsigned int8 tensors.
-*/
-struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool>
+struct parse_qlinearpooling : op_parser<parse_qlinearpooling>
 {
-    std::vector<op_desc> operators() const { return {{"QLinearGlobalAveragePool"}}; }
+    std::vector<op_desc> operators() const
+    {
+        return {{"QLinearGlobalAveragePool", "average"}, {"QLinearAveragePool", "average"}};
+    }
 
-    // basic type checking for QLinearGlobalAveragePool Operator
-    void check_inputs(const std::vector<instruction_ref>& args) const
+    void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
     {
-        if(args.size() < 5)
-            MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: missing inputs");
-
         const auto& in_x     = args[0];
-        const auto& zero_pt_x = args[2];
-        const auto& zero_pt_y = args[4];
+        const auto onnx_name = opd.onnx_name;
 
         if(in_x->get_shape().ndim() <= 2)
-            MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: input dimensions too small");
+            MIGRAPHX_THROW(onnx_name + ": input dimensions too small");
 
         auto type_x = in_x->get_shape().type();
         if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
-            MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: unsupported input type");
+            MIGRAPHX_THROW(onnx_name + ": unsupported input type");
 
+        const auto& zero_pt_x = args[2];
         if(type_x != zero_pt_x->get_shape().type())
-            MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point");
+            MIGRAPHX_THROW(onnx_name + ": mismatched type: input zero point");
 
+        if(args.size() == 5)
+        {
+            const auto& zero_pt_y = args[4];
             if(type_x != zero_pt_y->get_shape().type())
-                MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point");
+                MIGRAPHX_THROW(onnx_name + ": mismatched type: output zero point");
+        }
     }
 
-    instruction_ref parse(const op_desc& /* opd */,
+    instruction_ref parse(const op_desc& opd,
                           const onnx_parser& parser,
                           const onnx_parser::node_info& info,
                           const std::vector<instruction_ref>& args) const
     {
+        if(contains(info.attributes, "channel_last"))
+        {
             int channels_last =
                 parser.parse_value(info.attributes.at("channels_last")).template at<int>();
             if(channels_last != 0)
-                MIGRAPHX_THROW(
-                    "QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported");
+                MIGRAPHX_THROW(opd.onnx_name +
+                               ": channels_last (N x D1..Dn x C) is not supported");
+        }
 
-        check_inputs(args);
+        check_inputs(opd, args);
 
         // Input: X

@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool
         const auto& zero_pt_x = args[2];
         auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info);
 
-        // Output Y = globalaveragepool(X)
-        auto op    = migraphx::op::pooling{migraphx::op::pooling_mode::average};
-        auto lens  = in_x->get_shape().lens();
-        std::vector<size_t> lengths(lens.begin() + 2, lens.end());
-        op.lengths = lengths;
-        op.padding = std::vector<size_t>(lens.size());
-        auto out_y = info.add_instruction(op, dquant_x);
-
-        const auto& scale_y   = args[3];
-        const auto& zero_pt_y = args[4];
-        auto out_quant_y = bcast_qdq_instr("quantizelinear", out_y, scale_y, zero_pt_y, info);
-
-        return out_quant_y;
+        // Output Y = pooling_op(X)
+        auto out_y = add_pooling_op(opd, info, dquant_x);
+
+        const auto& in_scale_y = args[3];
+        // zero_pt for Y is supplied as the last optional argument..
+        if(args.size() == 5)
+            return (bcast_qdq_instr("quantizelinear", out_y, in_scale_y, args[4], info));
+
+        // if no zero_pt: just broadcast the scale..
+        auto bcast_scale_y = bcast_scalar_instr(out_y->get_shape(), in_scale_y, info);
+        return (info.add_instruction(migraphx::make_op("quantizelinear"), out_y, bcast_scale_y));
     }
 };
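Numeric sketch (illustrative only, assuming per-tensor uint8 quantization and non-empty input): the graph built above is dequantizelinear -> pooling -> quantizelinear, which for the global-average case reduces to the following per-tensor computation.

// Standalone illustration, not MIGraphX code.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::uint8_t qlinear_global_average(const std::vector<std::uint8_t>& x,
                                    float x_scale, std::uint8_t x_zero,
                                    float y_scale, std::uint8_t y_zero)
{
    double sum = 0.0;
    for(auto v : x)
        sum += x_scale * (static_cast<int>(v) - x_zero);               // dequantize each element
    double avg = sum / static_cast<double>(x.size());                  // global average pool
    int q = static_cast<int>(std::lround(avg / y_scale)) + y_zero;     // quantize the result
    return static_cast<std::uint8_t>(std::clamp(q, 0, 255));
}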
src/onnx/parse_qlinearunary.cpp (new file, 0 → 100644)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/*
 *********************************************************************************
 * Reference: see QLinearSigmoid, QLinearLeakyRelu in                            *
 * https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md   *
 *********************************************************************************

com.microsoft.QLinearSigmoid
QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces
one output data (Tensor) where the function f(x) = quantize(Sigmoid(dequantize(x))), is applied to
the data tensor elementwise. Where the function Sigmoid(x) = 1 / (1 + exp(-x))

Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.

*****************************************************************************************************

com.microsoft.QLinearLeakyRelu
QLinearLeakyRelu takes quantized input data (Tensor), an argument alpha, and quantize parameter for
output, and produces one output data (Tensor) where the function f(x) = quantize(alpha *
dequantize(x)) for dequantize(x) < 0, f(x) = quantize(dequantize(x)) for dequantize(x) >= 0, is
applied to the data tensor elementwise.

Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

Attributes
alpha : float
Coefficient of leakage.

******************************************************************************************************

Generic input layout of QLinear unary operators:

Inputs (4 - 5)
X : T
Input tensor
X_scale : tensor(float)
Input X's scale. It's a scalar, which means a per-tensor/layer quantization.
X_zero_point (optional) : T
Input X's zero point. Default value is 0 if it's not specified. It's a scalar, which means a
per-tensor/layer quantization.
Y_scale : tensor(float) Output Y's scale. It's a scalar, which means
a per-tensor/layer quantization.
Y_zero_point (optional) : T Output Y's zero point. Default value is
0 if it's not specified. It's a scalar, which means a per-tensor/layer quantization.

Outputs
Y : T
Output tensor

Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit tensors.
*/

struct parse_qlinearunary : op_parser<parse_qlinearunary>
{
    std::vector<op_desc> operators() const
    {
        return {{"QLinearSigmoid", "sigmoid"}, {"QLinearLeakyRelu", "leaky_relu"}};
    }

    void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
    {
        if(args.size() < 4)
            MIGRAPHX_THROW(opd.op_name + ": missing inputs");

        const auto& in_x = args[0];
        auto sh_x        = in_x->get_shape();
        auto type_x      = sh_x.type();
        if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
            MIGRAPHX_THROW(opd.op_name + ": unsupported input type");
    }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        check_inputs(opd, args);

        // X
        const auto& in_x         = args[0];
        const auto& in_scale_x   = args[1];
        const auto& in_zero_pt_x = args[2];
        auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, in_scale_x, in_zero_pt_x, info);

        // Y = (op(dequantizelinear(x))
        auto op = parser.load(opd.op_name, info);
        auto y  = info.add_instruction(op, dquant_x);

        const auto& in_scale_y = args[3];
        // zero_pt for Y is supplied as the last optional argument..
        if(args.size() == 5)
            return (bcast_qdq_instr("quantizelinear", y, in_scale_y, args[4], info));

        // if no zero_pt: just broadcast the scale..
        auto bcast_scale_sigm = bcast_scalar_instr(y->get_shape(), in_scale_y, info);
        return (info.add_instruction(migraphx::make_op("quantizelinear"), y, bcast_scale_sigm));
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
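Worked sketch (illustrative, not MIGraphX code) of the element-wise definitions quoted in the file comment above, assuming per-tensor uint8 quantization:

// QLinearSigmoid:   f(x) = quantize(Sigmoid(dequantize(x)))
// QLinearLeakyRelu: f(x) = quantize(alpha * dequantize(x)) if dequantize(x) < 0,
//                          quantize(dequantize(x))         otherwise
#include <algorithm>
#include <cmath>
#include <cstdint>

static std::uint8_t quantize(float v, float scale, std::uint8_t zero)
{
    int q = static_cast<int>(std::lround(v / scale)) + zero;
    return static_cast<std::uint8_t>(std::clamp(q, 0, 255));
}

static float dequantize(std::uint8_t q, float scale, std::uint8_t zero)
{
    return scale * (static_cast<int>(q) - zero);
}

std::uint8_t qlinear_sigmoid(std::uint8_t x, float xs, std::uint8_t xz, float ys, std::uint8_t yz)
{
    float v = dequantize(x, xs, xz);
    return quantize(1.0f / (1.0f + std::exp(-v)), ys, yz);
}

std::uint8_t qlinear_leaky_relu(std::uint8_t x, float alpha,
                                float xs, std::uint8_t xz, float ys, std::uint8_t yz)
{
    float v = dequantize(x, xs, xz);
    return quantize(v < 0 ? alpha * v : v, ys, yz);
}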
src/onnx/parse_scatternd.cpp

@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref>& args) const
     {
+        std::string reduction = "none";
         if(contains(info.attributes, "reduction"))
         {
-            if(info.attributes.at("reduction").s() == "add")
-                return info.add_instruction(migraphx::make_op("scatternd_add"), args);
-            if(info.attributes.at("reduction").s() == "mul")
-                return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
+            reduction = info.attributes.at("reduction").s();
+            if(not contains({"none", "add", "mul", "min", "max"}, reduction))
+            {
+                MIGRAPHX_THROW("PARSE_SCATTERND: unsupported reduction mode " + reduction);
+            }
         }
-        return info.add_instruction(migraphx::make_op("scatternd_none"), args);
+        return info.add_instruction(migraphx::make_op("scatternd_" + reduction), args);
     }
 };
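Sketch (not from the commit) of the new dispatch logic: validate the "reduction" attribute against the supported set and then form the operator name by concatenation, exactly as the rewritten parse() above does.

// Standalone illustration of the name dispatch.
#include <set>
#include <stdexcept>
#include <string>

std::string scatternd_op_name(const std::string& reduction)
{
    static const std::set<std::string> allowed = {"none", "add", "mul", "min", "max"};
    if(allowed.count(reduction) == 0)
        throw std::runtime_error("PARSE_SCATTERND: unsupported reduction mode " + reduction);
    return "scatternd_" + reduction; // e.g. "scatternd_add", "scatternd_min"
}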
src/onnx/parse_unique.cpp
0 → 100644
View file @
6f768035
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <optional>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
// generate unique output stream y, given input stream x;
//
// case unsorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// output(s):
// y: [2, 1, 3, 4] --- the unique output
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
//
// case sorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// output(s):
// y: [1, 2, 3, 4] --- the unique output
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
struct parse_unique : op_parser<parse_unique>
{
    std::vector<op_desc> operators() const { return {{"Unique"}}; }

    std::vector<instruction_ref> parse(const op_desc& opd,
                                       const onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        int64_t sorted = 1; // default = sorted.
        if(contains(info.attributes, "sorted"))
            sorted = parser.parse_value(info.attributes.at("sorted")).at<int>();

        std::optional<int64_t> axis;
        if(contains(info.attributes, "axis"))
        {
            auto n_dim = args[0]->get_shape().ndim();
            axis       = parser.parse_value(info.attributes.at("axis")).at<int>();
            axis       = tune_axis(n_dim, *axis, opd.op_name);
        }

        migraphx::argument data_arg = args.back()->eval();

        auto opr = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
                        : migraphx::make_op("unique", {{"sorted", sorted}});

        auto u_opr = info.add_instruction(opr, args.at(0));

        auto i_y     = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), u_opr);
        auto i_y_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), u_opr);
        auto i_x_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 2}}), u_opr);
        auto i_count = info.add_instruction(make_op("get_tuple_elem", {{"index", 3}}), u_opr);

        return {i_y, i_y_idx, i_x_idx, i_count};
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
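For readers unfamiliar with the ONNX Unique outputs, here is a minimal standalone sketch (illustrative only, not MIGraphX code; unique_1d is a made-up name) that reproduces the four outputs documented in the comment above for a 1-D input, in both the unsorted and sorted cases:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Compute y, y_indices, x_rev_indices and y_count for a 1-D input,
// mirroring the semantics documented in parse_unique.cpp above.
static void unique_1d(const std::vector<int>& x, bool sorted)
{
    std::vector<int> y;                 // unique values, in order of first occurrence
    std::vector<int64_t> y_indices;     // index into x of the first occurrence of each y
    std::vector<int64_t> x_rev_indices; // x rewritten as indices into y
    std::vector<int64_t> y_count;       // number of occurrences of each y

    for(std::size_t i = 0; i < x.size(); ++i)
    {
        auto it = std::find(y.begin(), y.end(), x[i]);
        if(it == y.end())
        {
            y.push_back(x[i]);
            y_indices.push_back(static_cast<int64_t>(i));
            y_count.push_back(0);
            it = y.end() - 1;
        }
        auto pos = static_cast<int64_t>(it - y.begin());
        x_rev_indices.push_back(pos);
        y_count[pos] += 1;
    }

    if(sorted)
    {
        // Sort y and remap the index outputs accordingly.
        std::vector<std::size_t> order(y.size());
        for(std::size_t i = 0; i < order.size(); ++i)
            order[i] = i;
        std::sort(order.begin(), order.end(), [&](auto a, auto b) { return y[a] < y[b]; });

        std::vector<std::size_t> new_pos(y.size()); // old position -> sorted position
        std::vector<int> ys(y.size());
        std::vector<int64_t> yi(y.size()), yc(y.size());
        for(std::size_t i = 0; i < order.size(); ++i)
        {
            new_pos[order[i]] = i;
            ys[i] = y[order[i]];
            yi[i] = y_indices[order[i]];
            yc[i] = y_count[order[i]];
        }
        for(auto& r : x_rev_indices)
            r = static_cast<int64_t>(new_pos[static_cast<std::size_t>(r)]);
        y = ys;
        y_indices = yi;
        y_count = yc;
    }

    auto print = [](const char* name, const auto& v) {
        std::cout << name << ": ";
        for(auto e : v)
            std::cout << e << " ";
        std::cout << "\n";
    };
    print("y", y);
    print("y_indices", y_indices);
    print("x_rev_indices", x_rev_indices);
    print("y_count", y_count);
}

int main()
{
    unique_1d({2, 1, 1, 3, 4, 3}, false); // unsorted case from the comment above
    unique_1d({2, 1, 1, 3, 4, 3}, true);  // sorted case from the comment above
}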
src/onnx/pooling.cpp
0 → 100644
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
                            onnx_parser::node_info info,
                            const shape& in_shape,
                            value values)
{
    auto kdims = in_shape.ndim() - 2;

    if(starts_with(opd.onnx_name, "Global") or starts_with(opd.onnx_name, "QLinearGlobal"))
    {
        // if spatial dimensions are dynamic use dyn_global flag
        if(in_shape.dynamic() and
           std::any_of(in_shape.dyn_dims().cbegin() + 2,
                       in_shape.dyn_dims().cend(),
                       [](auto dd) { return not dd.is_fixed(); }))
        {
            values["dyn_global"] = true;
            values["lengths"]    = std::vector<size_t>();
        }
        else
        {
            // works with static and fixed dynamic shape
            auto m_lens       = in_shape.max_lens();
            values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
        }
    }

    if(contains(info.attributes, "ceil_mode"))
    {
        values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
    }

    if(contains(info.attributes, "strides"))
    {
        values["stride"].clear();
        copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
        check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
    }

    if(contains(info.attributes, "kernel_shape"))
    {
        values["lengths"].clear();
        copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
        check_attr_sizes(kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
    }

    if(contains(info.attributes, "dilations"))
    {
        values["dilations"].clear();
        copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
        check_attr_sizes(
            kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
    }

    // lp_order attribute
    if(contains(info.attributes, "p"))
    {
        values["lp_order"] = info.attributes.at("p").i();
    }

    // ensure pads available only when auto_pad is "NOT_SET"
    check_padding_mode(info, "POOLING");

    return values;
}

instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0)
{
    std::string mode = opd.op_name;
    const std::unordered_map<std::string, op::pooling_mode> mode_map = {
        {"max", op::pooling_mode::max},
        {"average", op::pooling_mode::average},
        {"lpnorm", op::pooling_mode::lpnorm}};
    if(not contains(mode_map, mode))
    {
        MIGRAPHX_THROW(
            "PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
    }

    operation op  = make_op("pooling", {{"mode", mode_map.at(mode)}});
    value values  = op.to_value();
    auto in_shape = l0->get_shape();
    assert(in_shape.ndim() > 2);
    auto kdims = in_shape.ndim() - 2;

    values = handle_pooling_values(opd, info, in_shape, values);

    // count include padding, if count include pad is 1, we always use
    // explicit pad
    int count_include_pad = 0;
    if(contains(info.attributes, "count_include_pad"))
    {
        if(in_shape.dynamic())
        {
            MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
                           "dynamic input shape");
        }
        count_include_pad = info.attributes.at("count_include_pad").i();
    }

    std::vector<int64_t> paddings;
    float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);

    if(contains(info.attributes, "pads"))
    {
        values["padding"].clear();
        copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
        check_attr_sizes(
            kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
    }

    if(paddings.size() != 2 * kdims)
    {
        paddings.resize(kdims * 2);
        std::fill_n(paddings.begin(), 2 * kdims, 0);
    }

    if(values["padding"].size() != kdims)
    {
        values["padding"].resize(kdims);
        std::fill_n(values["padding"].begin(), kdims, 0);
    }

    if(values["stride"].size() != kdims)
    {
        values["stride"].resize(kdims);
        std::fill_n(values["stride"].begin(), kdims, 1);
    }

    if(values["dilations"].size() != kdims)
    {
        values["dilations"].resize(kdims);
        std::fill_n(values["dilations"].begin(), kdims, 1);
    }

    // used to calculate the supposed output shape
    std::vector<int64_t> orig_padding = paddings;

    // TODO: add parsing for dilations
    if(contains(info.attributes, "auto_pad") and
       to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
    {
        auto auto_pad = to_upper(info.attributes["auto_pad"].s());
        // don't use the given padding sizes, if any
        // values["padding"].clear();
        if(in_shape.dynamic())
        {
            // set padding_mode to trigger auto padding at runtime
            bool is_same_upper     = (auto_pad.find("SAME_UPPER") != std::string::npos);
            values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
                                                   : to_value(op::padding_mode_t::same_lower);
        }
        else
        {
            // Calculate auto padding
            // dilations (argument 4) not supported; default to all 1's
            cal_auto_padding_size(info,
                                  values,
                                  values["lengths"].to_vector<std::size_t>(),
                                  values["dilations"].to_vector<std::size_t>(),
                                  in_shape.lens(),
                                  paddings);
            values["padding"] = paddings;
            // default padding_mode indicates that padding sizes are not calculated dynamically
            values["padding_mode"] = migraphx::op::padding_mode_t::default_;
        }
    }

    std::vector<int64_t> slice_start;
    std::vector<int64_t> slice_end;
    tune_padding_size(values, paddings, count_include_pad, slice_start);

    if(not slice_start.empty())
    {
        if(in_shape.dynamic())
        {
            MIGRAPHX_THROW(
                "PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
        }
        // calculate expected output shape
        orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
        orig_padding.insert(orig_padding.begin(), 2, 0);
        op::pad pad{orig_padding, 0.0f};
        shape padded_shape = pad.compute_shape({l0->get_shape()});
        // make an op just to get its output shape
        auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
        // compute slice_end information
        slice_end.resize(slice_start.size());
        std::transform(out_lens.begin() + 2,
                       out_lens.end(),
                       slice_start.begin(),
                       slice_end.begin(),
                       [](auto i, auto j) { return i + j; });
    }
    values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
    check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
    op.from_value(values);

    auto l1 = info.add_instruction(op, l0);
    if(not slice_start.empty())
    {
        std::vector<int64_t> axes(kdims);
        std::iota(axes.begin(), axes.end(), 2);
        l1 = info.add_instruction(
            make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}), l1);
    }
    return l1;
}

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
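For the fixed-shape auto_pad branch above, cal_auto_padding_size fills in explicit pad sizes. A minimal sketch of the standard ONNX SAME_UPPER arithmetic this corresponds to (illustrative only; same_upper_pads is a made-up name, not the MIGraphX helper itself):

#include <algorithm>
#include <cstdint>
#include <vector>

// For each spatial axis, ONNX SAME_UPPER chooses pads so that out = ceil(in / stride),
// putting the extra pixel (if any) at the end. Returns pads as {begin..., end...},
// the same layout as the ONNX "pads" attribute.
std::vector<int64_t> same_upper_pads(const std::vector<int64_t>& in_lens,
                                     const std::vector<int64_t>& kernel,
                                     const std::vector<int64_t>& stride,
                                     const std::vector<int64_t>& dilation)
{
    auto n = in_lens.size();
    std::vector<int64_t> pads(2 * n, 0);
    for(std::size_t i = 0; i < n; ++i)
    {
        int64_t eff_kernel = dilation[i] * (kernel[i] - 1) + 1;
        int64_t out        = (in_lens[i] + stride[i] - 1) / stride[i]; // ceil(in / stride)
        int64_t total = std::max<int64_t>((out - 1) * stride[i] + eff_kernel - in_lens[i], 0);
        pads[i]     = total / 2;         // begin pad
        pads[i + n] = total - total / 2; // end pad gets the extra one for SAME_UPPER
    }
    return pads;
}

// Example: a 6x6 input with a 3x3 window and stride 2 gives out = 3x3, one pad pixel per
// axis in total, so pads = {0, 0, 1, 1} (nothing at the begin, one row/column at the end).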
src/rewrite_pooling.cpp
...
@@ -35,25 +35,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-void rewrite_pooling::apply(module& m) const
+static void replace_with_reduce(module& m, instruction_ref ins)
 {
-    for(auto ins : iterator_for(m))
-    {
-        if(ins->name() != "pooling")
-            continue;
-        if(ins->inputs().empty())
-            continue;
-        auto&& s  = ins->inputs().front()->get_shape();
-        auto&& op = any_cast<op::pooling>(ins->get_operator());
-        if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
-            continue;
-        if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
-            continue;
-        auto lens = s.lens();
-        if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
-            continue;
-        std::vector<std::int64_t> axes(lens.size() - 2);
-        std::iota(axes.begin(), axes.end(), 2);
-        // average pooling
-        if(op.mode == op::pooling_mode::average)
-        {
+    auto&& s  = ins->inputs().front()->get_shape();
+    auto&& op = any_cast<op::pooling>(ins->get_operator());
+    auto lens = s.lens();
+    std::vector<std::int64_t> axes(lens.size() - 2);
+    std::iota(axes.begin(), axes.end(), 2);
+    // average pooling
+    if(op.mode == op::pooling_mode::average)
+    {
...
@@ -64,6 +53,131 @@ void rewrite_pooling::apply(module& m) const
         {
             m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
         }
     }
 }

+static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
+{
+    // TODO remove this when MIOpen supports dilated pooling
+    auto&& s  = ins->inputs().front()->get_shape();
+    auto&& op = any_cast<op::pooling>(ins->get_operator());
+
+    // Ignore N, C axes
+    std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};
+
+    bool default_padding =
+        std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
+
+    if(not default_padding)
+    {
+        for(size_t idx{0}; idx < op.padding.size(); ++idx)
+        {
+            // We need to pad both ends
+            dims[idx] += op.padding.at(idx) * 2;
+        }
+    }
+
+    std::vector<size_t> kernels   = op.lengths;
+    std::vector<size_t> strides   = op.stride;
+    std::vector<size_t> dilations = op.dilations;
+
+    std::vector<std::vector<int>> axis_indices;
+    axis_indices.resize(dims.size());
+
+    for(auto idx{0}; idx < dims.size(); ++idx)
+    {
+        // Only consider if it fits into the window
+        for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
+            stride += strides.at(idx))
+        {
+            for(size_t step{0}; step < kernels.at(idx); ++step)
+            {
+                axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
+            }
+        }
+    }
+
+    auto elements = ins->inputs().front();
+    if(not default_padding)
+    {
+        // Pad supports asym, we need to provide both ends
+        std::vector<size_t> padding(2 * s.lens().size(), 0);
+        // Format will be e.g {N, C, P1, P2, N, C, P1, P2}
+        for(size_t idx{0}; idx < op.padding.size(); ++idx)
+        {
+            // Ignore N, C axes
+            padding.at(2 + idx)                   = op.padding.at(idx);
+            padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
+        }
+        // Default value needed for Max pooling
+        elements = m.insert_instruction(
+            ins,
+            make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
+            elements);
+    }
+
+    for(auto idx{0}; idx < axis_indices.size(); ++idx)
+    {
+        migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
+        auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
+        elements     = m.insert_instruction(
+            ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
+    }
+
+    // Ignore padding
+    std::vector<size_t> new_padding(kernels.size(), 0);
+    // The kernel window elements are placed next to each other. E.g. {x1, y1, x2, y2, ...}
+    // We need to skip them to not overlap
+    std::vector<size_t> new_strides(kernels);
+    // Ignore dilations
+    std::vector<size_t> new_dilations(kernels.size(), 1);
+
+    m.replace_instruction(ins,
+                          make_op("pooling",
+                                  {{"mode", op.mode},
+                                   {"padding", new_padding},
+                                   {"stride", new_strides},
+                                   {"lengths", kernels},
+                                   {"dilations", new_dilations}}),
+                          elements);
+}
+
+void rewrite_pooling::apply(module& m) const
+{
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != "pooling")
+            continue;
+        if(ins->inputs().empty())
+            continue;
+        auto&& s  = ins->inputs().front()->get_shape();
+        auto&& op = any_cast<op::pooling>(ins->get_operator());
+
+        bool same_kernel_as_shape = std::equal(
+            s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
+        bool default_strides =
+            std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
+        bool default_padding =
+            std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
+        bool default_dilations =
+            std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
+
+        if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
+        {
+            replace_with_reduce(m, ins);
+        }
+        else if(not default_dilations)
+        {
+            // Dilated AvgPool with padding is not supported
+            if(not default_padding and op.mode == op::pooling_mode::average)
+            {
+                continue;
+            }
+            auto size = std::accumulate(
+                s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
+            // Can't handle too much size because of literal size
+            if(size > 100000)
+            {
+                continue;
+            }
+            replace_dilations_with_gather_pooling(m, ins);
+        }
+    }
+}
...
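A small standalone sketch of the index construction used by replace_dilations_with_gather_pooling above (illustrative; window_indices is a made-up name): every window start whose last tap still fits on the axis emits its kernel taps one dilation apart, so a follow-up pooling with stride equal to the kernel length over the gathered data sees exactly the original dilated windows.

#include <cstddef>
#include <iostream>
#include <vector>

// Gather indices along one spatial axis for a dilated pooling window, mirroring the
// axis_indices computation above. Assumes the dilated window fits into the axis
// (dilation * (kernel - 1) < dim), the same assumption the pass makes.
std::vector<int> window_indices(std::size_t dim,
                                std::size_t kernel,
                                std::size_t stride,
                                std::size_t dilation)
{
    std::vector<int> out;
    for(std::size_t start = 0; start < dim - dilation * (kernel - 1); start += stride)
        for(std::size_t step = 0; step < kernel; ++step)
            out.push_back(static_cast<int>(start + dilation * step));
    return out;
}

int main()
{
    // Axis of length 7, kernel 2, stride 1, dilation 2:
    // windows are {0,2}, {1,3}, {2,4}, {3,5}, {4,6} -> prints 0 2 1 3 2 4 3 5 4 6
    for(auto i : window_indices(7, 2, 1, 2))
        std::cout << i << " ";
    std::cout << "\n";
}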
src/schedule.cpp
...
@@ -27,7 +27,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/iterator.hpp>
 #include <migraphx/dfor.hpp>
-#include <migraphx/par_for.hpp>
+#include <migraphx/simple_par_for.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/dom_info.hpp>
...
@@ -461,7 +461,7 @@ struct stream_info
                        std::back_inserter(index_to_ins),
                        [](auto&& it) { return it.first; });
-        par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
+        simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
             auto merge_first = index_to_ins[ins_index];
             assert(concur_ins.count(merge_first) > 0);
             auto& merge_second = concur_ins.at(merge_first);
...
src/simplify_dyn_ops.cpp
...
@@ -22,6 +22,7 @@
  * THE SOFTWARE.
  */
 #include <migraphx/simplify_dyn_ops.hpp>
+#include <migraphx/op/slice.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/literal.hpp>
...
@@ -65,8 +66,65 @@ struct find_static_2in_broadcasts
 };

 /**
- * Simplify slice with variable `starts` and `ends` to the constant version if
- * the `input_starts` and `input_ends` inputs are constant.
+ * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
+ * From:
+ * slice(data, constant_input); two attributes set
+ * To:
+ * slice(data); slice.starts, slice.ends, slice.axes set
+ */
+struct find_const_2in_slice
+{
+    auto matcher() const
+    {
+        return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
+    }
+
+    void apply(module& m, const match::matcher_result& mr) const
+    {
+        auto ins       = mr.result;
+        auto inputs    = ins->inputs();
+        auto slice_op  = any_cast<op::slice>(ins->get_operator());
+        auto set_attrs = slice_op.get_set_attributes();
+        std::vector<int64_t> starts_vec;
+        std::vector<int64_t> ends_vec;
+        std::vector<int64_t> axes_vec;
+        if(set_attrs == op::slice::ends_axes)
+        {
+            // slice(data, starts)
+            inputs.at(1)->eval().visit(
+                [&](auto output) { starts_vec.assign(output.begin(), output.end()); });
+            ends_vec = slice_op.ends;
+            axes_vec = slice_op.axes;
+        }
+        else if(set_attrs == op::slice::starts_axes)
+        {
+            // slice(data, ends)
+            inputs.at(1)->eval().visit(
+                [&](auto output) { ends_vec.assign(output.begin(), output.end()); });
+            starts_vec = slice_op.starts;
+            axes_vec   = slice_op.axes;
+        }
+        else
+        {
+            // slice(data, axes)
+            inputs.at(1)->eval().visit(
+                [&](auto output) { axes_vec.assign(output.begin(), output.end()); });
+            starts_vec = slice_op.starts;
+            ends_vec   = slice_op.ends;
+        }
+        m.replace_instruction(
+            ins,
+            make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
+            inputs.at(0));
+    }
+};
+
+/**
+ * Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
+ * From:
+ * slice(data, constant_input1, constant_input2); one attribute set
+ * To:
+ * slice(data); slice.starts, slice.ends, slice.axes set
  */
 struct find_const_3in_slice
 {
...
@@ -81,27 +139,51 @@ struct find_const_3in_slice
     {
         auto ins    = mr.result;
         auto inputs = ins->inputs();
-        argument starts_arg = inputs.at(1)->eval();
-        argument ends_arg   = inputs.at(2)->eval();
-        if(not starts_arg.empty() and not ends_arg.empty())
-        {
-            std::vector<int64_t> starts_vec;
-            std::vector<int64_t> ends_vec;
-            starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
-            ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
-            auto slice_val = ins->get_operator().to_value();
-            auto axes_vec  = slice_val.at("axes").to_vector<int64_t>();
-            m.replace_instruction(
-                ins,
-                make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
-                inputs.at(0));
-        }
+        auto slice_op  = any_cast<op::slice>(ins->get_operator());
+        auto set_attrs = slice_op.get_set_attributes();
+        std::vector<int64_t> starts_vec;
+        std::vector<int64_t> ends_vec;
+        std::vector<int64_t> axes_vec;
+        if(set_attrs == op::slice::axes_only)
+        {
+            // slice(data, starts, ends)
+            inputs.at(1)->eval().visit(
+                [&](auto output) { starts_vec.assign(output.begin(), output.end()); });
+            inputs.at(2)->eval().visit(
+                [&](auto output) { ends_vec.assign(output.begin(), output.end()); });
+            axes_vec = slice_op.axes;
+        }
+        else if(set_attrs == op::slice::ends_only)
+        {
+            // slice(data, starts, axes)
+            inputs.at(1)->eval().visit(
+                [&](auto output) { starts_vec.assign(output.begin(), output.end()); });
+            inputs.at(2)->eval().visit(
+                [&](auto output) { axes_vec.assign(output.begin(), output.end()); });
+            ends_vec = slice_op.ends;
+        }
+        else
+        {
+            // slice(data, ends, axes)
+            inputs.at(1)->eval().visit(
+                [&](auto output) { ends_vec.assign(output.begin(), output.end()); });
+            inputs.at(2)->eval().visit(
+                [&](auto output) { axes_vec.assign(output.begin(), output.end()); });
+            starts_vec = slice_op.starts;
+        }
+        m.replace_instruction(
+            ins,
+            make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
+            inputs.at(0));
     }
 };

 /**
- * Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
- * the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
+ * Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
+ * From:
+ * slice(data, constant_starts, constant_ends, constant_axes)
+ * To:
+ * slice(data); slice.starts, slice.ends, slice.axes set
  */
 struct find_const_4in_slice
 {
...
@@ -117,9 +199,9 @@ struct find_const_4in_slice
     {
         auto ins    = mr.result;
         auto inputs = ins->inputs();
-        argument starts_arg = inputs.at(1)->eval();
-        argument ends_arg   = inputs.at(2)->eval();
-        argument axes_arg   = inputs.at(3)->eval();
+        argument starts_arg = inputs.at(1)->eval(false);
+        argument ends_arg   = inputs.at(2)->eval(false);
+        argument axes_arg   = inputs.at(3)->eval(false);
         if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
         {
             std::vector<int64_t> starts_vec;
...
@@ -179,6 +261,7 @@ struct find_static_dimensions_of
 /**
  * Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
  * argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
+ * This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
  * From:
  * x = allocate(constant_output_dims) -> reshape(data, x)
  * To:
...
@@ -207,14 +290,44 @@ struct find_const_alloc_reshapes
     }
 };

+/**
+ * Simplify allocate into fill operator that has constant output dimensions and constant value.
+ * The allocate into fill instruction sequence is what is produced when parsing the ONNX
+ * ConstantOfShape operator. This replacement could be handled with propagate_constant, but we
+ * would rather have the simplification happen earlier during compiling.
+ * This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
+ * From:
+ * x = allocate(constant_output_dims) -> fill(constant_value, x)
+ * To:
+ * literal
+ */
+struct find_const_alloc_fill
+{
+    auto matcher() const
+    {
+        return match::name("fill")(match::arg(0)(match::is_constant()),
+                                   match::arg(1)(match::name("allocate")(match::is_constant())));
+    }
+
+    void apply(module& m, const match::matcher_result& mr) const
+    {
+        auto fill_ins = mr.result;
+        auto fill_arg = fill_ins->eval(false);
+        auto l        = m.add_literal(fill_arg.get_shape(), fill_arg.data());
+        m.replace_instruction(fill_ins, l);
+    }
+};
+
 void simplify_dyn_ops::apply(module& m) const
 {
     match::find_matches(m,
                         find_static_dimensions_of{},
                         find_const_alloc_reshapes{},
                         find_static_2in_broadcasts{},
+                        find_const_2in_slice{},
                         find_const_3in_slice{},
-                        find_const_4in_slice{});
+                        find_const_4in_slice{},
+                        find_const_alloc_fill{});
 }

 } // namespace MIGRAPHX_INLINE_NS
...
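A rough usage sketch of the new slice folding (illustrative only; it assumes MIGraphX's test-style pass driver run_passes and that a variable-input slice can be constructed directly as shown): a three-input slice whose starts and ends inputs are literals should collapse into the single-input form with all three attributes set, after which the literal inputs become dead code.

#include <cstdint>
#include <vector>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_dyn_ops.hpp>

int main()
{
    migraphx::module m;
    migraphx::shape xs{migraphx::shape::float_type, {4, 6}};
    auto x = m.add_parameter("x", xs);

    migraphx::shape is{migraphx::shape::int64_type, {2}};
    auto starts = m.add_literal(migraphx::literal{is, std::vector<int64_t>{0, 1}});
    auto ends   = m.add_literal(migraphx::literal{is, std::vector<int64_t>{2, 5}});

    // Variable-input form: slice(data, starts, ends) with only the axes attribute set.
    auto sl = m.add_instruction(migraphx::make_op("slice", {{"axes", {0, 1}}}), x, starts, ends);
    m.add_return({sl});

    // After the passes the slice should be the single-input form with
    // starts = {0, 1}, ends = {2, 5}, axes = {0, 1} baked into the operator.
    migraphx::run_passes(m, {migraphx::simplify_dyn_ops{}, migraphx::dead_code_elimination{}});
    m.debug_print();
}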
src/targets/cpu/dnnl.cpp
...
@@ -67,8 +67,8 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
     case st::float_type: return dt::f32;
     case st::int32_type: return dt::s32;
     case st::int8_type: return dt::s8;
-    case st::uint8_type:
-    case st::fp8e4m3fnuz_type: return dt::u8;
+    case st::uint8_type: return dt::u8;
+    case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
     default: MIGRAPHX_THROW("Unsupported data type");
     }
 }
...
src/targets/cpu/lowering.cpp
...
@@ -340,7 +340,6 @@ struct cpu_apply
             {"reduce_min", "reduction_min"},
             {"reduce_sum", "reduction_sum"},
         });
         extend_op("concat", "dnnl::concat");
         extend_op("contiguous", "dnnl::reorder");
         extend_op("convolution", "dnnl::convolution");
...
@@ -376,6 +375,12 @@ struct cpu_apply
         // Apply these operators first so the inputs can be const folded
         for(auto it : iterator_for(*modl))
         {
+            // skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
+            // supported yet.
+            if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
+                   return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
+               }))
+                continue;
             if(it->name() == "pow")
             {
                 apply_pow(it);
...
@@ -383,6 +388,12 @@ struct cpu_apply
         }
         for(auto it : iterator_for(*modl))
         {
+            // skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
+            // supported yet.
+            if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
+                   return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
+               }))
+                continue;
             if(it->name() == "pooling")
             {
                 apply_pooling(it);
...
src/targets/cpu/pooling.cpp
...
@@ -34,23 +34,32 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace cpu {

-struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
+struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_v2_forward, op::pooling>
 {
     std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }

-    dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
+    dnnl::pooling_v2_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
     {
         auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
                                                      : dnnl::algorithm::pooling_avg;
         auto kdims = op.kdims();
         std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
         std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
+        // Note: It is not documented, but the default dilation seems to be 0 instead of 1.
+        // We need to offset dilations with -1.
+        std::vector<size_t> dilations;
+        std::transform(op.dilations.cbegin(),
+                       op.dilations.cend(),
+                       std::back_inserter(dilations),
+                       [](size_t d) { return d - 1; });
         return {dnnl::prop_kind::forward_inference,
                 algo,
                 m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
                 m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
                 to_dnnl_dims(op.stride),
                 to_dnnl_dims(op.lengths),
+                to_dnnl_dims(dilations),
                 to_dnnl_dims(padding_l),
                 to_dnnl_dims(padding_r)};
     }
...
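A quick worked example of the dilation offset noted in the comment above (illustrative, host-only sketch): MIGraphX encodes "no dilation" as 1 per spatial axis, while the oneDNN pooling_v2 descriptor expects 0 for the same meaning, so each entry is shifted down by one before being converted with to_dnnl_dims.

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

int main()
{
    // op.dilations as MIGraphX sees them: {1, 1} means dense, {2, 2} means every other element.
    std::vector<std::size_t> migx_dilations{2, 2};

    // Same transform as in dnnl_pooling::get_desc above: oneDNN wants {1, 1} for that case.
    std::vector<std::size_t> dnnl_dilations;
    std::transform(migx_dilations.cbegin(),
                   migx_dilations.cend(),
                   std::back_inserter(dnnl_dilations),
                   [](std::size_t d) { return d - 1; });
    // dnnl_dilations == {1, 1}; a dense window {1, 1} would map to {0, 0}.
}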
src/targets/gpu/CMakeLists.txt
...
@@ -126,7 +126,6 @@ add_library(migraphx_gpu
     fuse_ck.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
-    gather.cpp
     gemm_impl.cpp
     hip.cpp
     kernel.cpp
...
@@ -140,7 +139,6 @@ add_library(migraphx_gpu
     nonzero.cpp
     pack_args.cpp
     prefuse_ops.cpp
-    pad.cpp
     perfdb.cpp
     pooling.cpp
     reverse.cpp
...
@@ -168,12 +166,10 @@ endfunction()
 register_migraphx_gpu_ops(hip_
     argmax
     argmin
-    gather
     logsoftmax
     loop
     multinomial
     nonzero
-    pad
     prefix_scan_sum
     reverse
     scatter
...
src/targets/gpu/compile_hip.cpp
...
@@ -194,7 +194,7 @@ struct hiprtc_program
 };

 std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
-                                                           std::string params,
+                                                           const std::string& params,
                                                            const std::string& arch)
 {
     hiprtc_program prog(std::move(srcs));
...
@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags)
 }
 }

 std::vector<std::vector<char>>
 compile_hip_src(const std::vector<src_file>& srcs,
-                std::string params,
+                const std::string& params,
                 const std::string& arch)
 {
     std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
     if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
...
@@ -281,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
         if(fs::exists(out))
             return {read_buffer(out.string())};
     }
-    return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch);
+    return compile_hip_src_with_hiprtc(std::move(hsrcs), params, arch);
 }

 #else // MIGRAPHX_USE_HIPRTC

 std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
-                                                           std::string,        // NOLINT
+                                                           const std::string&, // NOLINT
                                                            const std::string&)
 {
     MIGRAPHX_THROW("Not using hiprtc");
...
@@ -316,29 +317,15 @@ src_compiler assemble(src_compiler compiler)
     return compiler;
 }

 std::vector<std::vector<char>>
 compile_hip_src(const std::vector<src_file>& srcs,
-                std::string params, const std::string& arch)
+                const std::string& params, const std::string& arch)
 {
     assert(not srcs.empty());
     if(not is_hip_clang_compiler())
         MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);

-    if(params.find("-std=") == std::string::npos)
-        params += " --std=c++17";
-    params += " -fno-gpu-rdc";
-    if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
-        params += " -g";
-    params += " -c";
-    params += " --offload-arch=" + arch;
-    params += " --cuda-device-only";
-    params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
-    if(enabled(MIGRAPHX_GPU_DEBUG{}))
-        params += " -DMIGRAPHX_DEBUG";
-    params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
-    params += MIGRAPHX_HIP_COMPILER_FLAGS;
-
     src_compiler compiler;
     compiler.flags    = params;
     compiler.compiler = MIGRAPHX_HIP_COMPILER;
...
@@ -346,6 +333,23 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
     if(has_compiler_launcher())
         compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
 #endif

+    if(params.find("-std=") == std::string::npos)
+        compiler.flags += " --std=c++17";
+    compiler.flags += " -fno-gpu-rdc";
+    if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
+        compiler.flags += " -g";
+    compiler.flags += " -c";
+    compiler.flags += " --offload-arch=" + arch;
+    compiler.flags += " --cuda-device-only";
+    compiler.flags += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
+    if(enabled(MIGRAPHX_GPU_DEBUG{}))
+        compiler.flags += " -DMIGRAPHX_DEBUG";
+    compiler.flags += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
+    compiler.flags += MIGRAPHX_HIP_COMPILER_FLAGS;
+
     if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
     {
         for(const auto& src : srcs)
...
src/targets/gpu/compile_hip_code_object.cpp
...
@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
     options.params += " " + join_strings(compiler_warnings(), " ");
     options.params += " -ftemplate-backtrace-limit=0";
     options.params += " -Werror";
-    auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
+    auto cos = compile_hip_src(srcs, options.params, get_device_name());
     if(cos.size() != 1)
         MIGRAPHX_THROW("No code object");
     return code_object_op{value::binary{cos.front()},
...
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp
...
@@ -43,24 +43,32 @@ template <index_int N,
 __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
 {
     using type = decltype(input(deduce_for_stride(fs)));
-    MIGRAPHX_DEVICE_SHARED type buffer[N];
+    MIGRAPHX_DEVICE_SHARED type buffer[2][N];
     type x = init;
     fs([&](auto i) {
+        index_int iout = 0;
+        index_int iin  = 1;
         if(idx.local == 0)
-            buffer[idx.local] = op(input(i), x);
+            buffer[iout][idx.local] = op(input(i), x);
         else
-            buffer[idx.local] = input(i);
+            buffer[iout][idx.local] = input(i);
         __syncthreads();

         for(index_int s = 1; s < idx.nlocal(); s *= 2)
         {
-            if(idx.local + s < idx.nlocal())
+            iout = 1 - iout;
+            iin  = 1 - iin;
+            if(idx.local >= s)
             {
-                buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
+                buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
             }
+            else
+            {
+                buffer[iout][idx.local] = buffer[iin][idx.local];
+            }
             __syncthreads();
         }
-        x = buffer[idx.nlocal() - 1];
-        output(i, buffer[idx.local]);
+        x = buffer[iout][idx.nlocal() - 1];
+        output(i, buffer[iout][idx.local]);
     });
 }
...
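The change above switches block_scan from updating one shared buffer in place to a ping-pong pair of buffers: each Hillis-Steele step reads only from the previous buffer and writes only to the other one, so a thread can no longer read an element that a neighbour already overwrote within the same step. A small host-side sketch of the same idea (illustrative; plain C++ standing in for the HIP kernel, with the __syncthreads() barriers implicit in the loop structure):

#include <cstddef>
#include <iostream>
#include <vector>

// Inclusive prefix sum with the double-buffered Hillis-Steele scheme used by block_scan.
std::vector<int> inclusive_scan(const std::vector<int>& input)
{
    std::size_t n = input.size();
    std::vector<std::vector<int>> buffer(2, std::vector<int>(n));
    std::size_t iout = 0;
    std::size_t iin  = 1;

    buffer[iout] = input; // every "thread" writes its own element first

    for(std::size_t s = 1; s < n; s *= 2)
    {
        iout = 1 - iout; // swap roles; in the kernel this brackets a __syncthreads()
        iin  = 1 - iin;
        for(std::size_t local = 0; local < n; ++local) // all "threads" in parallel
        {
            if(local >= s)
                buffer[iout][local] = buffer[iin][local] + buffer[iin][local - s];
            else
                buffer[iout][local] = buffer[iin][local];
        }
    }
    return buffer[iout];
}

int main()
{
    for(auto v : inclusive_scan({1, 2, 3, 4, 5}))
        std::cout << v << " "; // prints 1 3 6 10 15
    std::cout << "\n";
}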