Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8d32c6b8
Commit
8d32c6b8
authored
Oct 17, 2023
by
Paul
Browse files
Merge branch 'develop' into blas_tuning
parents
23cb7917
f25606f9
Changes
386
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1288 additions
and
185 deletions
+1288
-185
src/onnx/parse_qlinearconv.cpp
src/onnx/parse_qlinearconv.cpp
+241
-0
src/onnx/parse_qlinearglavgpool.cpp
src/onnx/parse_qlinearglavgpool.cpp
+151
-0
src/onnx/parse_qlinearmatmul.cpp
src/onnx/parse_qlinearmatmul.cpp
+198
-0
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+16
-7
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+18
-23
src/onnx/parse_roialign.cpp
src/onnx/parse_roialign.cpp
+7
-4
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+82
-49
src/onnx/parse_spacetodepth.cpp
src/onnx/parse_spacetodepth.cpp
+1
-2
src/optimize_module.cpp
src/optimize_module.cpp
+7
-3
src/pad_calc.cpp
src/pad_calc.cpp
+40
-1
src/process.cpp
src/process.cpp
+167
-1
src/program.cpp
src/program.cpp
+2
-2
src/propagate_constant.cpp
src/propagate_constant.cpp
+4
-4
src/py/CMakeLists.txt
src/py/CMakeLists.txt
+14
-17
src/quantization.cpp
src/quantization.cpp
+5
-4
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+5
-16
src/shape.cpp
src/shape.cpp
+23
-16
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+118
-34
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+141
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+48
-2
No files found.
src/onnx/parse_qlinearconv.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
/*
*********************************************************************************
* Reference: see QLinearConv in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
com.microsoft.QLinearConv
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
ATTRIBUTES:
auto_pad : string
channels_last : int
dilations : list of ints
group : int
kernel_shape : list of ints
pads : list of ints
strides : list of ints
INPUTS (8 - 9):
x : T1
x_scale : tensor(float)
x_zero_point : T1
w : T2
w_scale : tensor(float)
w_zero_point : T2
y_scale : tensor(float)
y_zero_point : T3
B (optional) : T4
OUTPUTS:
y : T3
Type Constraints:
T1 : tensor(int8), tensor(uint8)
T2 : tensor(int8), tensor(uint8)
T3 : tensor(int8), tensor(uint8)
T4 : tensor(int32)
More details also at:
https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
*/
struct
parse_qlinearconv
:
op_parser
<
parse_qlinearconv
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"QLinearConv"
}};
}
// basic type checking for QLinearConv Operator
void
check_inputs
(
const
std
::
vector
<
instruction_ref
>&
inp_arg
)
const
{
if
(
inp_arg
.
size
()
<
8
)
MIGRAPHX_THROW
(
"QLINEARCONV: missing inputs"
);
const
instruction_ref
&
in_x
=
inp_arg
[
0
];
const
instruction_ref
&
in_scale_x
=
inp_arg
[
1
];
const
instruction_ref
&
in_w
=
inp_arg
[
3
];
const
instruction_ref
&
in_scale_w
=
inp_arg
[
4
];
const
instruction_ref
&
in_scale_y
=
inp_arg
[
6
];
auto
sh_x
=
in_x
->
get_shape
();
auto
sh_w
=
in_w
->
get_shape
();
auto
type_x
=
sh_x
.
type
();
auto
type_w
=
sh_w
.
type
();
assert
(
in_x
->
get_shape
().
ndim
()
>
2
);
if
(
type_x
!=
shape
::
int8_type
and
type_x
!=
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARCONV: unsupported input type"
);
if
(
type_w
!=
shape
::
int8_type
and
type_w
!=
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARCONV: unsupported weight type"
);
if
(
in_scale_x
->
get_shape
().
type
()
!=
shape
::
float_type
)
MIGRAPHX_THROW
(
"QLINEARCONV x scale type should be float"
);
if
(
in_scale_w
->
get_shape
().
type
()
!=
shape
::
float_type
)
MIGRAPHX_THROW
(
"QLINEARCONV: wt scale type should be float"
);
if
(
in_scale_y
->
get_shape
().
type
()
!=
shape
::
float_type
)
MIGRAPHX_THROW
(
"QLINEARCONV: y scale type should be float"
);
if
(
inp_arg
.
size
()
>
8
and
inp_arg
[
8
]
->
get_shape
().
type
()
!=
shape
::
int32_type
)
MIGRAPHX_THROW
(
"QLINEARCONV y bias should be int32"
);
}
// process all attributes of QLinearConv Operator..
value
process_attributes
(
const
onnx_parser
&
parser
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
value
values
;
const
auto
&
in_x
=
args
[
0
];
const
auto
&
wt
=
args
[
3
];
size_t
kdims
=
in_x
->
get_shape
().
ndim
()
-
2
;
check_padding_mode
(
info
,
"QLINEARCONV"
);
values
[
"stride"
]
=
std
::
vector
<
int
>
(
kdims
,
1
);
values
[
"dilation"
]
=
std
::
vector
<
int
>
(
kdims
,
1
);
values
[
"padding"
]
=
std
::
vector
<
int
>
(
kdims
,
0
);
values
[
"group"
]
=
1
;
if
(
contains
(
info
.
attributes
,
"group"
))
values
[
"group"
]
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"group"
)).
template
at
<
int
>();
if
(
contains
(
info
.
attributes
,
"strides"
))
{
std
::
vector
<
int
>
st
;
copy
(
info
.
attributes
.
at
(
"strides"
).
ints
(),
std
::
back_inserter
(
st
));
check_attr_sizes
(
kdims
,
st
.
size
(),
"QLINEARCONV: inconsistent strides"
);
values
[
"stride"
]
=
st
;
}
if
(
contains
(
info
.
attributes
,
"dilations"
))
{
std
::
vector
<
int
>
dil
;
copy
(
info
.
attributes
.
at
(
"dilations"
).
ints
(),
std
::
back_inserter
(
dil
));
check_attr_sizes
(
kdims
,
dil
.
size
(),
"QLINEARCONV: inconsistent dilations"
);
values
[
"dilation"
]
=
dil
;
}
if
(
contains
(
info
.
attributes
,
"pads"
))
{
std
::
vector
<
int
>
pads
;
copy
(
info
.
attributes
.
at
(
"pads"
).
ints
(),
std
::
back_inserter
(
pads
));
check_attr_sizes
(
kdims
,
pads
.
size
()
/
2
,
"QLINEARCONV: inconsistent padding"
);
values
[
"padding"
]
=
pads
;
}
else
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
{
auto
in_lens
=
in_x
->
get_shape
().
lens
();
auto
wt_lens
=
wt
->
get_shape
().
lens
();
std
::
vector
<
std
::
size_t
>
k_lens
(
wt_lens
.
begin
()
+
2
,
wt_lens
.
end
());
std
::
vector
<
int64_t
>
pads
=
values
[
"padding"
].
to_vector
<
std
::
int64_t
>
();
cal_auto_padding_size
(
info
,
values
,
k_lens
,
values
[
"dilation"
].
to_vector
<
std
::
size_t
>
(),
in_lens
,
pads
);
values
[
"padding"
]
=
pads
;
}
recalc_conv_attributes
(
values
,
kdims
);
return
values
;
}
instruction_ref
add_bias_to_conv
(
const
instruction_ref
bias_arg
,
const
instruction_ref
conv_instr
,
const
onnx_parser
::
node_info
&
info
)
const
{
auto
conv_sh
=
conv_instr
->
get_shape
();
auto
conv_lens
=
conv_sh
.
lens
();
auto
conv_type
=
conv_sh
.
type
();
auto
broadcast_bias
=
info
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
conv_lens
}}),
bias_arg
);
auto
f_bias
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
conv_type
}}),
broadcast_bias
);
return
info
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
conv_instr
,
f_bias
);
};
instruction_ref
parse
(
const
op_desc
&
/* opd */
,
const
onnx_parser
&
parser
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
check_inputs
(
args
);
auto
values
=
process_attributes
(
parser
,
info
,
args
);
// input: quantized x, scale, zero_pt
const
instruction_ref
&
in_x
=
args
[
0
];
const
instruction_ref
&
in_scale_x
=
args
[
1
];
const
instruction_ref
&
in_zero_pt_x
=
args
[
2
];
// input: quantized weights, scale, zero_pt
const
instruction_ref
&
in_w
=
args
[
3
];
const
instruction_ref
&
in_scale_w
=
args
[
4
];
const
instruction_ref
&
in_zero_pt_w
=
args
[
5
];
// for the dequantized output y: scale & zero_pt
const
instruction_ref
&
in_scale_y
=
args
[
6
];
const
instruction_ref
&
in_zero_pt_y
=
args
[
7
];
auto
dquant_x
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_x
,
in_scale_x
,
in_zero_pt_x
,
info
);
auto
dquant_w
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_w
,
in_scale_w
,
in_zero_pt_w
,
info
);
auto
conv_op
=
migraphx
::
make_op
(
"convolution"
,
values
);
auto
conv_x_w
=
info
.
add_instruction
(
conv_op
,
dquant_x
,
dquant_w
);
// Biases, if any.. : is an optional argument.
if
(
args
.
size
()
>
8
)
conv_x_w
=
add_bias_to_conv
(
args
[
8
],
conv_x_w
,
info
);
auto
quant_conv
=
bcast_qdq_instr
(
"quantizelinear"
,
conv_x_w
,
in_scale_y
,
in_zero_pt_y
,
info
);
return
quant_conv
;
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_qlinearglavgpool.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
/*
*********************************************************************************
* Reference: see QLinearGlobalAveragePool in *
* github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
QLinearGlobalAveragePool consumes an input tensor X and applies
Average pooling across the values in the same channel. This is
equivalent to AveragePool with kernel size equal to the spatial
dimension of input tensor. Input is of type uint8_t or int8_t.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
channels_last : int
Inputs
X : T
Input data tensor from the previous operator; According to channels_last, dimensions for image case
are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
H and W are the height and the width of the data. For non image case, the dimensions are in the form
of (N x C x D1 x D2 ... Dn), or (N x D1 X D2 ... Dn x C) where N is the batch size.
x_scale : tensor(float)
Scale of quantized input 'X'. It must be a scalar.
x_zero_point : T
Zero point tensor for input 'X'. It must be a scalar.
y_scale : tensor(float)
Scale of quantized output 'Y'. It must be a scalar.
y_zero_point : T
Zero point tensor for output 'Y'. It must be a scalar.
Outputs
Y : T
Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
input. with the N and C value keep it value, while the other dimensions are all 1. Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to signed/unsigned int8 tensors.
*/
struct
parse_qlinearglobalaveragepool
:
op_parser
<
parse_qlinearglobalaveragepool
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"QLinearGlobalAveragePool"
}};
}
// basic type checking for QLinearGlobalAveragePool Operator
void
check_inputs
(
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
if
(
args
.
size
()
<
5
)
MIGRAPHX_THROW
(
"QLINEARGLOBALAVERAGEPOOL: missing inputs"
);
const
auto
&
in_x
=
args
[
0
];
const
auto
&
zero_pt_x
=
args
[
2
];
const
auto
&
zero_pt_y
=
args
[
4
];
if
(
in_x
->
get_shape
().
ndim
()
<=
2
)
MIGRAPHX_THROW
(
"QLINEARGLOBALAVERAGEPOOL: input dimensions too small"
);
auto
type_x
=
in_x
->
get_shape
().
type
();
if
(
type_x
!=
migraphx
::
shape
::
int8_type
and
type_x
!=
migraphx
::
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARGLOBALAVERAGEPOOL: unsupported input type"
);
if
(
type_x
!=
zero_pt_x
->
get_shape
().
type
())
MIGRAPHX_THROW
(
"QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point"
);
if
(
type_x
!=
zero_pt_y
->
get_shape
().
type
())
MIGRAPHX_THROW
(
"QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point"
);
}
instruction_ref
parse
(
const
op_desc
&
/* opd */
,
const
onnx_parser
&
parser
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
int
channels_last
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"channels_last"
)).
template
at
<
int
>();
if
(
channels_last
!=
0
)
MIGRAPHX_THROW
(
"QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported"
);
check_inputs
(
args
);
// Input: X
const
auto
&
in_x
=
args
[
0
];
const
auto
&
scale_x
=
args
[
1
];
const
auto
&
zero_pt_x
=
args
[
2
];
auto
dquant_x
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_x
,
scale_x
,
zero_pt_x
,
info
);
// Output Y = globalaveragepool(X)
auto
op
=
migraphx
::
op
::
pooling
{
migraphx
::
op
::
pooling_mode
::
average
};
auto
lens
=
in_x
->
get_shape
().
lens
();
std
::
vector
<
size_t
>
lengths
(
lens
.
begin
()
+
2
,
lens
.
end
());
op
.
lengths
=
lengths
;
op
.
padding
=
std
::
vector
<
size_t
>
(
lens
.
size
());
auto
out_y
=
info
.
add_instruction
(
op
,
dquant_x
);
const
auto
&
scale_y
=
args
[
3
];
const
auto
&
zero_pt_y
=
args
[
4
];
auto
out_quant_y
=
bcast_qdq_instr
(
"quantizelinear"
,
out_y
,
scale_y
,
zero_pt_y
,
info
);
return
out_quant_y
;
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_qlinearmatmul.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
/*
*********************************************************************************
* Reference: see QLinearMatMul in *
* https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html *
*********************************************************************************
Matrix product that behaves like numpy.matmul:
https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. It consumes two
quantized input tensors, their scales and zero points, scale and zero point of output, and computes
the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point). For (x
/ y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding
for details. Scale and zero point must have same shape. They must be either scalar (per tensor) or
N-D tensor (per row for ‘a’ and per column for ‘b’). Scalar refers to per tensor quantization
whereas N-D refers to per row or per column quantization. If the input is 2D of shape [M, K] then
zero point and scale tensor may be an M element vector [v_1, v_2, …, v_M] for per row quantization
and K element vector of shape [v_1, v_2, …, v_K] for per column quantization. If the input is N-D
tensor with shape [D1, D2, M, K] then zero point and scale tensor may have shape [D1, D2, M, 1] for
per row quantization and shape [D1, D2, 1, K] for per column quantization. Production must never
overflow, and accumulation may overflow if and only if in 32 bits.
Inputs
a (heterogeneous) - T1: N-dimensional quantized matrix a
a_scale (heterogeneous) - tensor(float): scale of quantized input a
a_zero_point (heterogeneous) - T1: zero point of quantized input a
b (heterogeneous) - T2: N-dimensional quantized matrix b
b_scale (heterogeneous) - tensor(float): scale of quantized input b
b_zero_point (heterogeneous) - T2: zero point of quantized input b
y_scale (heterogeneous) - tensor(float): scale of quantized output y
y_zero_point (heterogeneous) - T3: zero point of quantized output y
Outputs
y (heterogeneous) - T3: Quantized matrix multiply results from a * b
Type Constraints
T1 in ( tensor(int8), tensor(uint8) ): Constrain input a and its zero point data type to 8-bit
integer tensor.
T2 in ( tensor(int8), tensor(uint8) ): Constrain input b and its zero point data type to 8-bit
integer tensor.
T3 in ( tensor(int8), tensor(uint8) ): Constrain output y and its zero point data type to 8-bit
integer tensor.
*/
struct
parse_qlinearmatmul
:
op_parser
<
parse_qlinearmatmul
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"QLinearMatMul"
}};
}
// basic type checking for QLinearMatMul Operator
void
check_inputs
(
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
if
(
args
.
size
()
<
8
)
MIGRAPHX_THROW
(
"QLINEARMATMUL: missing inputs"
);
const
auto
&
in_a
=
args
[
0
];
const
auto
&
in_b
=
args
[
3
];
auto
sh_a
=
in_a
->
get_shape
();
auto
sh_b
=
in_b
->
get_shape
();
auto
type_a
=
sh_a
.
type
();
auto
type_b
=
sh_b
.
type
();
if
(
type_a
!=
migraphx
::
shape
::
int8_type
and
type_a
!=
migraphx
::
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARMATMUL: unsupported input type"
);
if
(
type_b
!=
migraphx
::
shape
::
int8_type
and
type_b
!=
migraphx
::
shape
::
uint8_type
)
MIGRAPHX_THROW
(
"QLINEARMATMUL: unsupported input type"
);
auto
lens_a
=
sh_a
.
lens
();
auto
lens_b
=
sh_b
.
lens
();
size_t
dim_a
=
lens_a
.
size
();
size_t
dim_b
=
lens_b
.
size
();
if
(
dim_a
==
0
or
dim_b
==
0
)
MIGRAPHX_THROW
(
"QLINEARMATMUL: empty input"
);
// broadcast supported if either is 1-D -- the other can be a 2-D tensor.
// if it is 1-D, just prepend/append that lens and check further constraints..
if
(
dim_a
==
1
)
{
lens_a
.
insert
(
lens_a
.
begin
(),
1
);
dim_a
++
;
}
if
(
dim_b
==
1
)
{
lens_b
.
push_back
(
1
);
dim_b
++
;
}
// 2-D or higher-order mat mul
if
(
dim_a
!=
dim_b
or
*
lens_a
.
rbegin
()
!=
*
(
lens_b
.
rbegin
()
+
1
)
or
not
std
::
equal
(
lens_a
.
rbegin
()
+
2
,
lens_a
.
rend
(),
lens_b
.
rbegin
()
+
2
,
lens_b
.
rend
()))
MIGRAPHX_THROW
(
"QLINEARMATMUL: mismatched input dimensions"
);
if
(
migraphx
::
any_of
({
args
[
1
],
args
[
2
],
args
[
4
],
args
[
5
]},
[](
auto
arg
)
{
return
not
arg
->
get_shape
().
scalar
();
}))
MIGRAPHX_THROW
(
"QLINEARMATMUL: unsupported row/column quantization"
);
}
instruction_ref
parse
(
const
op_desc
&
/* opd */
,
const
onnx_parser
&
/*parser*/
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
check_inputs
(
args
);
// A
const
auto
&
in_a
=
args
[
0
];
const
auto
&
in_scale_a
=
args
[
1
];
const
auto
&
in_zero_pt_a
=
args
[
2
];
auto
dquant_a
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_a
,
in_scale_a
,
in_zero_pt_a
,
info
);
// B
const
auto
&
in_b
=
args
[
3
];
const
auto
&
in_scale_b
=
args
[
4
];
const
auto
&
in_zero_pt_b
=
args
[
5
];
auto
dquant_b
=
bcast_qdq_instr
(
"dequantizelinear"
,
in_b
,
in_scale_b
,
in_zero_pt_b
,
info
);
bool
is_a_prepended
=
false
;
bool
is_b_appended
=
false
;
// un-squeeze either tensor if 1-D.
if
(
in_a
->
get_shape
().
ndim
()
==
1
)
{
is_a_prepended
=
true
;
dquant_a
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
dquant_a
);
}
if
(
in_b
->
get_shape
().
ndim
()
==
1
)
{
is_b_appended
=
true
;
dquant_b
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
dquant_b
);
}
// Y = A * B
auto
out_y
=
info
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
dquant_a
,
dquant_b
);
// squeeze just once if necessary.. not twice.
if
(
is_a_prepended
)
out_y
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
out_y
);
else
if
(
is_b_appended
)
out_y
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
}}}),
out_y
);
const
auto
&
scale_y
=
args
[
6
];
const
auto
&
zero_pt_y
=
args
[
7
];
return
bcast_qdq_instr
(
"quantizelinear"
,
out_y
,
scale_y
,
zero_pt_y
,
info
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/onnx/parse_reshape.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -45,16 +45,25 @@ struct parse_reshape : op_parser<parse_reshape>
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
if
(
args
.
size
()
==
2
)
else
{
// 2 inputs
auto
s
=
args
[
1
]
->
eval
();
check_arg_empty
(
s
,
"Reshape: non-constant shape input is not supported"
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
if
(
s
.
empty
())
{
// arg[1] not eval-able
auto
alloc_ins
=
info
.
add_instruction
(
make_op
(
"allocate"
,
{{
"buf_type"
,
args
[
0
]
->
get_shape
().
type
()}}),
args
[
1
]);
return
info
.
add_instruction
(
make_op
(
"reshape"
),
args
[
0
],
alloc_ins
);
}
else
{
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
}
auto
cont
=
info
.
add_instruction
(
make_op
(
"contiguous"
),
args
[
0
]);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
cont
);
}
};
...
...
src/onnx/parse_resize.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
static
std
::
vector
<
int
>
calc_neighbor_points
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
size_t
>>>&
vvv_ind
,
int
i_dim
,
const
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
&
vec_dims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims
,
const
shape
&
in_s
)
{
if
(
i_dim
==
vvv_ind
.
size
())
{
std
::
vector
<
int
>
vec_ind
;
vec_ind
.
resize
(
vec_dims
.
size
());
std
::
vector
<
int
>
vec_ind
(
vec_dims
.
size
());
std
::
transform
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
vec_ind
.
begin
(),
[
&
](
auto
idx
)
{
return
static_cast
<
int
>
(
in_s
.
index
(
idx
));
});
return
vec_ind
;
}
const
auto
&
vv_ind
=
vvv_ind
[
i_dim
];
const
auto
&
vv_lo
=
vv_ind
.
at
(
0
);
const
auto
&
vv_lo
=
vvv_ind
[
i_dim
][
0
];
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims1
;
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_lo
.
size
())
{
...
...
@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
});
}
const
auto
&
vv_hi
=
vv_ind
.
at
(
1
)
;
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_
lo
.
size
())
const
auto
&
vv_hi
=
vv
v
_ind
[
i_dim
][
1
]
;
for
(
std
::
size_t
start
=
0
;
start
<
vec_dims
.
size
();
start
+=
vv_
hi
.
size
())
{
std
::
transform
(
vv_hi
.
begin
(),
vv_hi
.
end
(),
...
...
@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
return
dim
;
});
}
return
calc_neighbor_points
(
vvv_ind
,
i_dim
+
1
,
vec_dims1
,
in_s
);
vec_dims
.
clear
();
return
calc_neighbor_points
(
vvv_ind
,
i_dim
+
1
,
std
::
move
(
vec_dims1
)
,
in_s
);
}
static
std
::
string
get_coord_trans_mode
(
const
onnx_parser
::
attribute_map
&
attr
)
...
...
@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
auto
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
...
...
@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
"PARSE_"
+
opd
.
op_name
+
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
auto
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
...
...
@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
// map out_idx to in_idx
auto
nearest_op
=
get_nearest_op
(
nearest_mode
);
shape_for_each
(
out_s
,
[
&
](
auto
idx
)
{
auto
in_idx
=
idx
;
shape_for_each
(
out_s
,
[
&
](
const
auto
&
out_idx_v
,
size_t
out_
idx
)
{
std
::
vector
<
size_t
>
in_idx
(
out_idx_v
.
size
())
;
for
(
auto
ii
=
0
;
ii
<
in_lens
.
size
();
++
ii
)
{
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
idx
[
ii
],
vec_scale
[
ii
]);
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
out_
idx
_v
[
ii
],
vec_scale
[
ii
]);
in_idx
[
ii
]
=
nearest_op
(
in_lens
[
ii
],
idx_val
);
}
ind
[
out_
s
.
index
(
idx
)
]
=
static_cast
<
int64_t
>
(
in_s
.
index
(
in_idx
));
ind
[
out_idx
]
=
static_cast
<
int64_t
>
(
in_s
.
index
(
in_idx
));
});
shape
ind_s
{
shape
::
int32_type
,
out_lens
};
...
...
@@ -327,20 +324,18 @@ struct parse_resize : op_parser<parse_resize>
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
size_t
>>>
vvv_ind
(
n_dim
,
vv_ind
);
std
::
vector
<
std
::
vector
<
float
>>
delta
(
n_dim
,
std
::
vector
<
float
>
(
out_elements
));
shape_for_each
(
out_s
,
[
&
](
auto
idx
)
{
auto
in_idx
=
idx
;
auto
out_idx
=
out_s
.
index
(
idx
);
shape_for_each
(
out_s
,
[
&
](
const
auto
&
out_idx_v
,
size_t
out_idx
)
{
for
(
auto
ii
=
0
;
ii
<
in_lens
.
size
();
++
ii
)
{
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
idx
[
ii
],
vec_scale
[
ii
]);
auto
idx_val
=
idx_op
(
in_lens
[
ii
],
out_lens
[
ii
],
out_
idx
_v
[
ii
],
vec_scale
[
ii
]);
vvv_ind
[
ii
][
0
][
out_idx
]
=
nearest_floor
(
in_lens
[
ii
],
idx_val
);
vvv_ind
[
ii
][
1
][
out_idx
]
=
nearest_ceil
(
in_lens
[
ii
],
idx_val
);
delta
[
ii
][
out_idx
]
=
idx_val
-
vvv_ind
[
ii
][
0
][
out_idx
];
}
});
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
vec_dims
(
out_eleme
nts
);
auto
ind
=
calc_neighbor_points
(
vvv_ind
,
0
,
vec_dims
,
in_s
);
auto
ind
=
calc_neighbor_poi
nts
(
vvv_ind
,
0
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
(
out_elements
)
,
in_s
);
auto
ind_lens
=
out_lens
;
ind_lens
[
0
]
*=
(
std
::
size_t
{
1
}
<<
n_dim
);
shape
ind_s
{
shape
::
int32_type
,
ind_lens
};
...
...
src/onnx/parse_roialign.cpp
View file @
8d32c6b8
...
...
@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"RoiAlign"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*
parser
*/
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
std
::
string
coord_trans_mode
=
"half_pixel"
;
if
(
contains
(
info
.
attributes
,
"coordinate_transformation_mode"
))
std
::
string
coord_trans_mode
=
parser
.
opset_version
>=
16
?
"half_pixel"
:
"output_half_pixel"
;
if
(
const
auto
*
a
=
"coordinate_transformation_mode"
;
contains
(
info
.
attributes
,
a
))
{
coord_trans_mode
=
info
.
attributes
.
at
(
"coordinate_transformation_mode"
).
s
();
coord_trans_mode
=
info
.
attributes
.
at
(
a
).
s
();
}
if
(
not
contains
({
"half_pixel"
,
"output_half_pixel"
},
coord_trans_mode
))
{
MIGRAPHX_THROW
(
"coordinate_transformation_mode
\"
"
+
coord_trans_mode
+
...
...
src/onnx/parse_slice.cpp
View file @
8d32c6b8
...
...
@@ -34,16 +34,65 @@ namespace onnx {
struct
parse_slice
:
op_parser
<
parse_slice
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Slice"
}};
}
struct
slice_desc
{
op
::
slice
op
;
std
::
vector
<
instruction_ref
>
op_args
;
std
::
vector
<
int64_t
>
steps
;
std
::
vector
<
int64_t
>
raxes
;
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
{
std
::
vector
<
int64_t
>
result
;
migraphx
::
argument
arg_value
=
arg
->
eval
();
if
(
arg_value
.
empty
())
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
else
{
arg_value
.
visit
([
&
](
auto
s
)
{
result
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
return
result
;
}
};
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>
&
args
)
const
{
op
::
slice
op
;
auto
sd
=
construct_slice_desc
(
parser
,
info
,
args
);
auto
ins
=
info
.
add_instruction
(
sd
.
op
,
sd
.
op_args
);
if
(
not
sd
.
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
sd
.
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
transform
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
std
::
back_inserter
(
nsteps
),
[](
auto
s
)
{
return
std
::
abs
(
s
);
});
return
ins
=
info
.
add_instruction
(
make_op
(
"step"
,
{{
"axes"
,
sd
.
op
.
axes
},
{
"steps"
,
nsteps
}}),
ins
);
}
else
return
ins
;
}
std
::
vector
<
int64_t
>
steps
;
slice_desc
construct_slice_desc
(
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
slice_desc
sd
;
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice.
...
...
@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
check_arg_empty
(
step_arg
,
"PARSE_SLICE: cannot handle variable steps for slice"
);
step_arg
.
visit
([
&
](
auto
s
)
{
steps
.
assign
(
s
.
begin
(),
s
.
end
());
});
step_arg
.
visit
([
&
](
auto
s
)
{
sd
.
steps
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
if
(
args
.
size
()
>=
4
)
{
migraphx
::
argument
axes_arg
=
args
.
at
(
3
)
->
eval
();
check_arg_empty
(
axes_arg
,
"PARSE_SLICE: cannot handle variable axes for slice"
);
axes_arg
.
visit
([
&
](
auto
s
)
{
op
.
axes
.
assign
(
s
.
begin
(),
s
.
end
());
});
sd
.
op
.
axes
=
sd
.
insert
(
args
.
at
(
3
));
}
else
if
(
contains
(
info
.
attributes
,
"axes"
))
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axes"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
axes
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
axes
));
});
}
if
(
args
.
size
()
>=
3
)
{
migraphx
::
argument
end_arg
=
args
.
at
(
2
)
->
eval
();
check_arg_empty
(
end_arg
,
"PARSE_SLICE: cannot handle variable ends for slice"
);
end_arg
.
visit
([
&
](
auto
s
)
{
op
.
ends
.
assign
(
s
.
begin
(),
s
.
end
());
});
sd
.
op
.
ends
=
sd
.
insert
(
args
.
at
(
2
));
}
else
if
(
contains
(
info
.
attributes
,
"ends"
))
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"ends"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
ends
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
ends
));
});
}
if
(
args
.
size
()
>=
2
)
{
migraphx
::
argument
start_arg
=
args
.
at
(
1
)
->
eval
();
check_arg_empty
(
start_arg
,
"PARSE_SLICE: cannot handle variable starts for slice"
);
start_arg
.
visit
([
&
](
auto
s
)
{
op
.
starts
.
assign
(
s
.
begin
(),
s
.
end
());
});
sd
.
op
.
starts
=
sd
.
insert
(
args
.
at
(
1
));
}
else
if
(
contains
(
info
.
attributes
,
"starts"
))
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"starts"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
starts
));
});
}
// data input argument
sd
.
always_insert
(
args
.
at
(
0
));
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
if
(
sd
.
op
.
axes
.
empty
()
and
sd
.
op_args
.
size
()
<
3
)
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
sd
.
op
.
axes
=
axes
;
}
std
::
vector
<
int64_t
>
raxes
;
if
(
not
sd
.
steps
.
empty
())
{
if
(
sd
.
op
.
starts
.
empty
()
or
sd
.
op
.
ends
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable starts and ends is not supported"
);
if
(
sd
.
op
.
axes
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable axes is not supported"
);
}
assert
(
steps
.
empty
()
or
steps
.
size
()
==
op
.
axes
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
assert
(
sd
.
steps
.
empty
()
or
sd
.
steps
.
size
()
==
sd
.
op
.
axes
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
for
(
auto
i
:
range
(
sd
.
steps
.
size
()))
{
if
(
steps
[
i
]
>=
0
)
if
(
sd
.
steps
[
i
]
>=
0
)
continue
;
op
.
starts
[
i
]
+=
1
;
if
(
op
.
starts
[
i
]
==
0
)
op
.
starts
[
i
]
=
INT_MAX
;
op
.
ends
[
i
]
+=
1
;
raxes
.
push_back
(
op
.
axes
[
i
]);
std
::
swap
(
op
.
starts
[
i
],
op
.
ends
[
i
]);
}
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
sd
.
op
.
starts
[
i
]
+=
1
;
if
(
sd
.
op
.
starts
[
i
]
==
0
)
sd
.
op
.
starts
[
i
]
=
INT_MAX
;
sd
.
op
.
ends
[
i
]
+=
1
;
sd
.
raxes
.
push_back
(
sd
.
op
.
axes
[
i
]);
std
::
swap
(
sd
.
op
.
starts
[
i
],
sd
.
op
.
ends
[
i
]);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
transform
(
steps
.
begin
(),
steps
.
end
(),
std
::
back_inserter
(
nsteps
),
[](
auto
s
)
{
return
std
::
abs
(
s
);
});
return
ins
=
info
.
add_instruction
(
make_op
(
"step"
,
{{
"axes"
,
op
.
axes
},
{
"steps"
,
nsteps
}}),
ins
);
}
else
return
ins
;
return
sd
;
}
};
...
...
src/onnx/parse_spacetodepth.cpp
View file @
8d32c6b8
...
...
@@ -73,8 +73,7 @@ struct parse_spacetodepth : op_parser<parse_spacetodepth>
std
::
vector
<
int64_t
>
perm
=
{
0
,
3
,
5
,
1
,
2
,
4
};
auto
temp1
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
trans_lens
}}),
args
[
0
]);
auto
temp2
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
temp1
);
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
res_lens
}}),
info
.
make_contiguous
(
temp2
));
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
res_lens
}}),
temp2
);
}
};
...
...
src/optimize_module.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
mpm
.
run_pass
(
simplify_reshapes
{});
mpm
.
run_pass
(
simplify_algebra
{});
// loop to further optimize after initial transformations
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
mpm
.
run_pass
(
simplify_reshapes
{});
mpm
.
run_pass
(
simplify_algebra
{});
}
mpm
.
run_pass
(
eliminate_common_subexpression
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
propagate_constant
{});
...
...
src/pad_calc.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -52,6 +52,11 @@ void calculate_padding(int64_t idx,
}
}
/**
* Given the input array dimensions; kernel (wei_lens); strides; and dilations,
* calculate the padding value in each dimension.
*
*/
std
::
vector
<
std
::
size_t
>
calc_dyn_auto_pad
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
const
std
::
vector
<
std
::
size_t
>&
wei_lens
,
const
std
::
vector
<
std
::
size_t
>&
strides
,
...
...
@@ -60,6 +65,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
{
std
::
vector
<
std
::
size_t
>
padding
;
assert
(
input_lens
.
size
()
>=
3
);
assert
(
input_lens
.
size
()
==
wei_lens
.
size
());
std
::
size_t
num_spatial_dims
=
input_lens
.
size
()
-
2
;
padding
.
resize
(
2
*
num_spatial_dims
);
for
(
std
::
size_t
i
=
0
;
i
<
num_spatial_dims
;
i
++
)
...
...
@@ -88,6 +94,11 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
return
padding
;
}
/**
* Calculate the correct output shape for a convolution with
* a given input size and other parameters.
*
*/
shape
compute_padded_shape
(
const
shape
&
input
,
const
shape
&
weights
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
...
...
@@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input,
return
input
.
with_lens
(
output_lens
);
}
/**
* Calculate the correct output shape for a pooling with
* a given input size and other parameters. This uses
* the same formula for pooling that compute_padded_shape() uses
* for convolutions, but takes slightly different inputs.
*
*/
shape
compute_padded_pool_shape
(
const
shape
&
input
,
const
shape
&
kernel
,
const
std
::
vector
<
std
::
size_t
>&
padding
,
const
std
::
vector
<
std
::
size_t
>&
stride
,
const
std
::
vector
<
std
::
size_t
>&
dilation
)
{
const
size_t
num_spatial_dims
=
input
.
lens
().
size
()
-
2
;
std
::
vector
<
size_t
>
output_lens
{
input
.
lens
()[
0
],
input
.
lens
()[
1
]};
// calculate the output shape of the pooling: ((W - K + 2P) / S) + 1
for
(
size_t
i
=
0
;
i
<
num_spatial_dims
;
++
i
)
{
auto
padding_factor
=
padding
[
i
]
+
padding
[
i
+
num_spatial_dims
];
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
kernel
.
lens
()[
i
]
-
1
))
+
padding_factor
)
/
stride
[
i
]
+
1
)));
}
return
input
.
with_lens
(
output_lens
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/process.cpp
View file @
8d32c6b8
...
...
@@ -26,13 +26,23 @@
#include <migraphx/env.hpp>
#include <functional>
#include <iostream>
#include <optional>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <unistd.h>
#endif
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_CMD_EXECUTE
)
#ifndef _WIN32
std
::
function
<
void
(
const
char
*
)
>
redirect_to
(
std
::
ostream
&
os
)
{
return
[
&
](
const
char
*
x
)
{
os
<<
x
;
};
...
...
@@ -74,6 +84,155 @@ int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
});
}
#else
constexpr
std
::
size_t
MIGRAPHX_PROCESS_BUFSIZE
=
4096
;
class
pipe
{
public:
explicit
pipe
(
bool
inherit_handle
=
true
)
{
SECURITY_ATTRIBUTES
attrs
;
attrs
.
nLength
=
sizeof
(
SECURITY_ATTRIBUTES
);
attrs
.
bInheritHandle
=
inherit_handle
?
TRUE
:
FALSE
;
attrs
.
lpSecurityDescriptor
=
nullptr
;
if
(
CreatePipe
(
&
m_read
,
&
m_write
,
&
attrs
,
0
)
==
FALSE
)
throw
GetLastError
();
if
(
SetHandleInformation
(
&
m_read
,
HANDLE_FLAG_INHERIT
,
0
)
==
FALSE
)
throw
GetLastError
();
}
pipe
(
const
pipe
&
)
=
delete
;
pipe
&
operator
=
(
const
pipe
&
)
=
delete
;
pipe
(
pipe
&&
)
=
default
;
~
pipe
()
{
CloseHandle
(
m_read
);
m_read
=
nullptr
;
CloseHandle
(
m_write
);
m_write
=
nullptr
;
}
std
::
optional
<
std
::
pair
<
bool
,
DWORD
>>
read
(
LPVOID
buffer
,
DWORD
length
)
const
{
DWORD
bytes_read
;
if
(
ReadFile
(
m_read
,
buffer
,
length
,
&
bytes_read
,
nullptr
)
==
FALSE
)
{
DWORD
error
{
GetLastError
()};
if
(
error
!=
ERROR_MORE_DATA
)
{
return
std
::
nullopt
;
}
return
{{
true
,
bytes_read
}};
}
return
{{
false
,
bytes_read
}};
}
HANDLE
get_read_handle
()
const
{
return
m_read
;
}
bool
write
(
LPCVOID
buffer
,
DWORD
length
)
const
{
DWORD
bytes_written
;
return
WriteFile
(
m_write
,
buffer
,
length
,
&
bytes_written
,
nullptr
)
==
TRUE
;
}
HANDLE
get_write_handle
()
const
{
return
m_write
;
}
private:
HANDLE
m_write
=
nullptr
,
m_read
=
nullptr
;
};
template
<
typename
F
>
int
exec
(
const
std
::
string
&
cmd
,
F
f
)
{
try
{
if
(
enabled
(
MIGRAPHX_TRACE_CMD_EXECUTE
{}))
std
::
cout
<<
cmd
<<
std
::
endl
;
STARTUPINFO
info
;
PROCESS_INFORMATION
process_info
;
pipe
in
{},
out
{};
ZeroMemory
(
&
info
,
sizeof
(
STARTUPINFO
));
info
.
cb
=
sizeof
(
STARTUPINFO
);
info
.
hStdError
=
out
.
get_write_handle
();
info
.
hStdOutput
=
out
.
get_write_handle
();
info
.
hStdInput
=
in
.
get_read_handle
();
info
.
dwFlags
|=
STARTF_USESTDHANDLES
;
ZeroMemory
(
&
process_info
,
sizeof
(
process_info
));
if
(
CreateProcess
(
nullptr
,
const_cast
<
LPSTR
>
(
cmd
.
c_str
()),
nullptr
,
nullptr
,
TRUE
,
0
,
nullptr
,
nullptr
,
&
info
,
&
process_info
)
==
FALSE
)
{
return
GetLastError
();
}
f
(
in
,
out
);
WaitForSingleObject
(
process_info
.
hProcess
,
INFINITE
);
DWORD
status
{};
GetExitCodeProcess
(
process_info
.
hProcess
,
&
status
);
CloseHandle
(
process_info
.
hProcess
);
CloseHandle
(
process_info
.
hThread
);
return
static_cast
<
int
>
(
status
);
}
// cppcheck-suppress catchExceptionByValue
catch
(
DWORD
last_error
)
{
return
last_error
;
}
}
int
exec
(
const
std
::
string
&
cmd
)
{
TCHAR
buffer
[
MIGRAPHX_PROCESS_BUFSIZE
];
HANDLE
std_out
{
GetStdHandle
(
STD_OUTPUT_HANDLE
)};
return
(
std_out
==
nullptr
or
std_out
==
INVALID_HANDLE_VALUE
)
?
GetLastError
()
:
exec
(
cmd
,
[
&
](
const
pipe
&
,
const
pipe
&
out
)
{
for
(;;)
{
if
(
auto
result
=
out
.
read
(
buffer
,
MIGRAPHX_PROCESS_BUFSIZE
))
{
auto
[
more_data
,
bytes_read
]
=
*
result
;
if
(
not
more_data
or
bytes_read
==
0
)
break
;
DWORD
written
;
if
(
WriteFile
(
std_out
,
buffer
,
bytes_read
,
&
written
,
nullptr
)
==
FALSE
)
break
;
}
}
});
}
int
exec
(
const
std
::
string
&
cmd
,
std
::
function
<
void
(
process
::
writer
)
>
std_in
)
{
return
exec
(
cmd
,
[
&
](
const
pipe
&
in
,
const
pipe
&
)
{
std_in
([
&
](
const
char
*
buffer
,
std
::
size_t
n
)
{
in
.
write
(
buffer
,
n
);
});
});
}
#endif
struct
process_impl
{
std
::
string
command
{};
...
...
@@ -119,7 +278,14 @@ process& process::cwd(const fs::path& p)
return
*
this
;
}
void
process
::
exec
()
{
impl
->
check_exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
}
void
process
::
exec
()
{
#ifndef _WIN32
impl
->
check_exec
(
impl
->
get_command
(),
redirect_to
(
std
::
cout
));
#else
impl
->
check_exec
(
impl
->
get_command
());
#endif
}
void
process
::
write
(
std
::
function
<
void
(
process
::
writer
)
>
pipe_in
)
{
...
...
src/program.cpp
View file @
8d32c6b8
...
...
@@ -347,7 +347,7 @@ void program::finalize()
template
<
class
T
>
std
::
string
classify
(
T
x
)
{
switch
(
std
::
fpclassify
(
x
))
switch
(
std
::
fpclassify
(
static_cast
<
double
>
(
x
)
))
{
case
FP_INFINITE
:
return
"inf"
;
case
FP_NAN
:
return
"nan"
;
...
...
@@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
*/
const
int
program_file_version
=
6
;
const
int
program_file_version
=
7
;
value
program
::
to_value
()
const
{
...
...
src/propagate_constant.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_PROPAGATE_CONSTANT
)
bool
skip_prop
o
gate
(
instruction_ref
ins
)
bool
skip_prop
a
gate
(
instruction_ref
ins
)
{
if
(
ins
->
name
()
==
"contiguous"
)
return
skip_prop
o
gate
(
ins
->
inputs
().
front
());
return
skip_prop
a
gate
(
ins
->
inputs
().
front
());
auto
&&
s
=
ins
->
get_shape
();
if
(
s
.
broadcasted
()
and
not
s
.
scalar
())
return
true
;
...
...
@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return
false
;
}
bool
is_const_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_prop
o
gate
(
ins
);
}
bool
is_const_ins
(
instruction_ref
ins
)
{
return
ins
->
can_eval
()
and
not
skip_prop
a
gate
(
ins
);
}
void
propagate_constant
::
apply
(
module
&
m
)
const
{
...
...
src/py/CMakeLists.txt
View file @
8d32c6b8
...
...
@@ -22,27 +22,24 @@
# THE SOFTWARE.
#####################################################################################
option
(
MIGRAPHX_ENABLE_PYTHON
"Enable python bindings"
ON
)
add_library
(
migraphx_py py_loader.cpp
)
migraphx_generate_export_header
(
migraphx_py
)
target_include_directories
(
migraphx_py PRIVATE include
)
target_link_libraries
(
migraphx_py PUBLIC migraphx
)
rocm_install_targets
(
TARGETS migraphx_py INCLUDE include
)
if
(
MIGRAPHX_ENABLE_PYTHON
)
include
(
PythonModules
)
include
(
PythonModules
)
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
py_add_module
(
migraphx_pybind_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_pybind_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
rocm_install_targets
(
TARGETS migraphx_pybind_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_pybind_
${
PYTHON_VERSION
}
)
add_library
(
migraphx_py_
${
PYTHON_VERSION
}
py.cpp
)
target_include_directories
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE include
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PUBLIC migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE pybind11::pybind11 python
${
PYTHON_VERSION
}
::runtime
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
endforeach
()
endif
()
foreach
(
PYTHON_VERSION
${
PYTHON_VERSIONS
}
)
py_add_module
(
migraphx_pybind_
${
PYTHON_VERSION
}
migraphx_py.cpp PYTHON_VERSION
${
PYTHON_VERSION
}
PYTHON_MODULE migraphx
)
target_link_libraries
(
migraphx_pybind_
${
PYTHON_VERSION
}
PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets
)
rocm_install_targets
(
TARGETS migraphx_pybind_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_pybind_
${
PYTHON_VERSION
}
)
add_library
(
migraphx_py_
${
PYTHON_VERSION
}
py.cpp
)
target_include_directories
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE include
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PUBLIC migraphx
)
target_link_libraries
(
migraphx_py_
${
PYTHON_VERSION
}
PRIVATE pybind11::pybind11 python
${
PYTHON_VERSION
}
::runtime
)
rocm_install_targets
(
TARGETS migraphx_py_
${
PYTHON_VERSION
}
)
add_dependencies
(
migraphx_py migraphx_py_
${
PYTHON_VERSION
}
)
endforeach
()
src/quantization.cpp
View file @
8d32c6b8
...
...
@@ -70,6 +70,10 @@ void quantize_int8(program& prog,
MIGRAPHX_THROW
(
"QUANTIZE_INT8: only support DOT and CONVOLUTION operation"
);
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes
(
prog
,
{
optimize_module
{}});
std
::
shared_ptr
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
int8_quant_params
=
std
::
make_shared
<
std
::
vector
<
std
::
pair
<
float
,
float
>>>
();
std
::
shared_ptr
<
std
::
vector
<
float
>>
max_abs_vals
=
std
::
make_shared
<
std
::
vector
<
float
>>
();
...
...
@@ -143,10 +147,7 @@ void quantize_int8(program& prog,
run_passes
(
prog
,
{
quantize_int8_pass
{
ins_names
,
*
int8_quant_params
},
eliminate_common_subexpression
{},
dead_code_elimination
{},
simplify_reshapes
{},
dead_code_elimination
{},
optimize_module
{},
simplify_qdq
{},
dead_code_elimination
{}});
}
...
...
src/rewrite_pooling.cpp
View file @
8d32c6b8
...
...
@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const
continue
;
if
(
ins
->
inputs
().
empty
())
continue
;
auto
&&
s
=
ins
->
inputs
().
front
()
->
get_shape
();
if
(
not
s
.
standard
())
continue
;
auto
&&
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
&&
op
=
any_cast
<
op
::
pooling
>
(
ins
->
get_operator
());
if
(
not
std
::
all_of
(
op
.
padding
.
begin
(),
op
.
padding
.
end
(),
[](
auto
i
)
{
return
i
==
0
;
}))
continue
;
...
...
@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const
auto
lens
=
s
.
lens
();
if
(
not
std
::
equal
(
lens
.
begin
()
+
2
,
lens
.
end
(),
op
.
lengths
.
begin
(),
op
.
lengths
.
end
()))
continue
;
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
auto
reshape
=
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
n
*
c
,
-
1
}}}),
ins
->
inputs
().
front
());
instruction_ref
pooling
{};
std
::
vector
<
std
::
int64_t
>
axes
(
lens
.
size
()
-
2
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
2
);
// average pooling
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
pooling
=
m
.
insert
_instruction
(
ins
,
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
}
}}),
reshape
);
m
.
replace
_instruction
(
ins
,
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
ins
->
inputs
()
);
}
// max pooling
else
{
pooling
=
m
.
insert
_instruction
(
ins
,
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}
}}),
reshape
);
m
.
replace
_instruction
(
ins
,
make_op
(
"reduce_max"
,
{{
"axes"
,
axes
}}),
ins
->
inputs
()
);
}
std
::
vector
<
int64_t
>
rsp_lens
(
lens
.
size
(),
1
);
rsp_lens
[
0
]
=
n
;
rsp_lens
[
1
]
=
c
;
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
pooling
);
}
}
...
...
src/shape.cpp
View file @
8d32c6b8
...
...
@@ -50,13 +50,14 @@ struct shape_impl
{
assert
(
t
!=
shape
::
tuple_type
);
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
{
assert
(
t
!=
shape
::
tuple_type
);
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
}
shape_impl
(
shape
::
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_strides
(
std
::
move
(
s
))
{
...
...
@@ -151,6 +152,22 @@ struct shape_impl
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
std
::
size_t
get_index
(
size_t
i
)
const
{
std
::
size_t
result
=
0
;
std
::
size_t
s
=
1
;
for
(
auto
k
:
migraphx
::
reverse
(
migraphx
::
range
(
m_lens
.
size
())))
{
std
::
size_t
stride
=
m_strides
[
k
];
std
::
size_t
len
=
m_lens
[
k
];
std
::
size_t
idx
=
(
i
%
(
s
*
len
))
/
s
;
result
+=
stride
*
idx
;
s
*=
len
;
}
return
result
;
}
std
::
vector
<
std
::
size_t
>
min_lens
()
const
{
std
::
vector
<
std
::
size_t
>
ret
(
m_dyn_dims
.
size
());
...
...
@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
}
MIGRAPHX_THROW
(
"Invalid type"
);
}
std
::
string
shape
::
cpp_type
(
shape
::
type_t
t
)
{
switch
(
t
)
...
...
@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape
::
shape
()
:
impl
(
shape_impl
::
default_shape
())
{}
shape
::
shape
(
type_t
t
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
))
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
)))
{
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
t
,
std
::
move
(
l
),
std
::
move
(
s
)))
{
...
...
@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
standard
())
return
i
;
else
{
std
::
size_t
s
=
1
;
std
::
size_t
result
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
this
->
lens
().
size
();
j
++
)
{
const
std
::
size_t
k
=
this
->
lens
().
size
()
-
j
-
1
;
const
std
::
size_t
stride
=
this
->
strides
()[
k
];
const
std
::
size_t
len
=
this
->
lens
()[
k
];
const
std
::
size_t
idx
=
(
i
%
(
s
*
len
))
/
s
;
result
+=
stride
*
idx
;
s
*=
len
;
}
return
result
;
}
return
impl
->
get_index
(
i
);
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
idx
)
const
...
...
src/simplify_algebra.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -521,6 +521,27 @@ struct find_inner_broadcast
})
<
(
lens
.
size
()
-
1
);
}))
return
;
if
(
broadcasts
.
size
()
>
1
)
{
auto
bcast_strides
=
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
std
::
vector
<
size_t
>
common_axis
(
bcast_strides
,
0
);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for
(
auto
i
=
0
;
i
<
bcast_strides
;
i
++
)
{
for
(
const
auto
&
broadcast
:
broadcasts
)
{
if
(
broadcast
->
get_shape
().
strides
()[
i
]
==
0
)
common_axis
[
i
]
++
;
}
}
// if no common broadcast axis, transformation is not useful
if
(
std
::
find_if
(
common_axis
.
begin
(),
common_axis
.
end
(),
[](
auto
num_common
)
{
return
num_common
>
1
;
})
==
common_axis
.
end
())
return
;
}
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
...
...
@@ -1325,48 +1346,59 @@ struct find_split_reshape
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
slc
=
r
.
instructions
[
"slice"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
auto
slc
=
r
.
instructions
[
"slice"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
auto
input
=
slc
->
inputs
().
front
();
// Only apply simplification when slices are on a single axis
auto
axes
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
;
if
(
axes
.
size
()
>
1
)
{
return
;
}
auto
input
=
slc
->
inputs
().
front
();
auto
split_outputs
=
get_splits
(
input
);
if
(
split_outputs
.
empty
())
{
return
;
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
if
(
i
->
outputs
().
size
()
==
1
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
size
()
!=
1
;
}
return
false
;
}))
// Find all the reshapes (similar to rsp) that can be simplified
std
::
vector
<
instruction_ref
>
conts
;
std
::
vector
<
instruction_ref
>
vec_rsp
;
// Iterate through slice and contiguous outputs to allow simplifications when
// slice is followed by multiple reshapes
for
(
auto
&
i
:
split_outputs
)
{
return
;
std
::
copy_if
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
std
::
back_inserter
(
conts
),
[](
auto
j
)
{
return
j
->
name
()
==
"contiguous"
;
});
}
std
::
vector
<
instruction_ref
>
vec_rsp
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_rsp
.
begin
(),
[](
auto
i
)
{
auto
cont
=
i
->
outputs
().
front
();
return
cont
->
outputs
().
front
();
});
for
(
auto
&
i
:
conts
)
{
std
::
copy_if
(
i
->
outputs
().
begin
(),
i
->
outputs
().
end
(),
std
::
back_inserter
(
vec_rsp
),
[
&
](
auto
j
)
{
return
j
->
get_operator
()
==
rsp
->
get_operator
();
});
}
// all outputs are reshape and of the same shape
auto
dims
=
any_cast
<
op
::
reshape
>
(
rsp
->
get_operator
()).
dims
;
if
(
not
same_ops
(
vec_rsp
))
// No simplification needed if there is only one slice -> cont -> reshape
if
(
vec_rsp
.
size
()
<=
1
)
{
return
;
}
// ensure reshape happens after the axis dimension
auto
axis
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
[
0
];
auto
axis
=
axes
[
0
];
auto
slc_lens
=
slc
->
get_shape
().
lens
();
auto
slc_dim_size
=
std
::
accumulate
(
slc_lens
.
begin
()
+
axis
,
slc_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
auto
input_lens
=
input
->
get_shape
().
lens
();
auto
input_size
=
input
->
get_shape
().
elements
();
auto
slc_axis_len
=
input_lens
[
axis
];
// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
...
...
@@ -1393,16 +1425,67 @@ struct find_split_reshape
{
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
);
}
// calculate reshape output shape
std
::
vector
<
int64_t
>
vec_dims
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
vec_dims
.
begin
(),
[
&
](
auto
is
)
{
return
is
->
get_shape
().
lens
()[
rsp_axis
];
});
// Calculate reshape output shape
// Need to find a reshape such that data represented by instructions in vec_rsp can be
// written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
rsp_out_lens
[
rsp_axis
]
=
1
;
auto
rsp_fixed_size
=
std
::
accumulate
(
rsp_out_lens
.
begin
(),
rsp_out_lens
.
end
(),
1
,
std
::
multiplies
<
std
::
size_t
>
());
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
// cannot create a valid reshape for simplification
if
(
input_size
%
rsp_fixed_size
!=
0
)
{
return
;
}
auto
rsp_axis_len
=
input_size
/
rsp_fixed_size
;
rsp_out_lens
[
rsp_axis
]
=
rsp_axis_len
;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std
::
vector
<
int64_t
>
new_starts
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
new_starts
.
begin
(),
[
&
](
auto
is
)
{
auto
cont
=
is
->
inputs
().
front
();
auto
og_slc
=
cont
->
inputs
().
front
();
return
any_cast
<
op
::
slice
>
(
og_slc
->
get_operator
()).
starts
[
0
]
*
rsp_axis_len
/
slc_axis_len
;
});
std
::
vector
<
int64_t
>
new_ends
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
new_ends
.
begin
(),
[
&
](
auto
is
)
{
auto
cont
=
is
->
inputs
().
front
();
auto
og_slc
=
cont
->
inputs
().
front
();
return
any_cast
<
op
::
slice
>
(
og_slc
->
get_operator
()).
ends
[
0
]
*
rsp_axis_len
/
slc_axis_len
;
});
// insert the reshape instruction and add contiguous if needed
if
(
not
input
->
get_shape
().
standard
())
...
...
@@ -1413,16 +1496,14 @@ struct find_split_reshape
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
// replace the original reshape with slice
int64_t
start
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
{
m
.
replace_instruction
(
vec_rsp
[
i
],
make_op
(
"slice"
,
{{
"axes"
,
{
rsp_axis
}},
{
"starts"
,
{
start
}},
{
"ends"
,
{
start
+
vec_dim
s
[
i
]}}}),
{{
"axes"
,
{
rsp_axis
}},
{
"starts"
,
{
new_
start
s
[
i
]
}},
{
"ends"
,
{
new_end
s
[
i
]}}}),
rsp_ins
);
start
+=
vec_dims
[
i
];
}
}
};
...
...
@@ -1446,10 +1527,13 @@ struct find_split_transpose
{
return
;
}
if
(
std
::
any_of
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
[](
auto
i
)
{
return
i
->
outputs
().
size
()
!=
1
;
}))
return
;
std
::
vector
<
instruction_ref
>
vec_trans
(
split_outputs
.
size
());
std
::
transform
(
split_outputs
.
begin
(),
split_outputs
.
end
(),
vec_trans
.
begin
(),
[](
auto
i
)
{
assert
(
i
->
outputs
().
size
()
==
1
);
return
i
->
outputs
().
front
();
});
...
...
src/simplify_dyn_ops.cpp
0 → 100644
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct
find_static_2in_broadcasts
{
auto
matcher
()
const
{
return
match
::
broadcast
(
match
::
nargs
(
2
),
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
out_lens
=
ins
->
get_shape
().
lens
();
auto
broadcast_op
=
ins
->
get_operator
();
if
(
broadcast_op
.
name
()
==
"broadcast"
)
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
}});
}
else
{
broadcast_op
.
from_value
({{
"out_lens"
,
out_lens
},
{
"out_dyn_dims"
,
{}}});
}
m
.
replace_instruction
(
ins
,
broadcast_op
,
ins
->
inputs
().
at
(
0
));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct
find_const_3in_slice
{
auto
matcher
()
const
{
return
match
::
name
(
"slice"
)(
match
::
nargs
(
3
),
match
::
arg
(
1
)(
match
::
is_constant
()),
match
::
arg
(
2
)(
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
inputs
=
ins
->
inputs
();
argument
starts_arg
=
inputs
.
at
(
1
)
->
eval
();
argument
ends_arg
=
inputs
.
at
(
2
)
->
eval
();
if
(
not
starts_arg
.
empty
()
and
not
ends_arg
.
empty
())
{
std
::
vector
<
int64_t
>
starts_vec
;
std
::
vector
<
int64_t
>
ends_vec
;
starts_arg
.
visit
([
&
](
auto
output
)
{
starts_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
ends_arg
.
visit
([
&
](
auto
output
)
{
ends_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
auto
slice_val
=
ins
->
get_operator
().
to_value
();
auto
axes_vec
=
slice_val
.
at
(
"axes"
).
to_vector
<
int64_t
>
();
m
.
replace_instruction
(
ins
,
make_op
(
"slice"
,
{{
"starts"
,
starts_vec
},
{
"ends"
,
ends_vec
},
{
"axes"
,
axes_vec
}}),
inputs
.
at
(
0
));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct
find_const_4in_slice
{
auto
matcher
()
const
{
return
match
::
name
(
"slice"
)(
match
::
nargs
(
4
),
match
::
arg
(
1
)(
match
::
is_constant
()),
match
::
arg
(
2
)(
match
::
is_constant
()),
match
::
arg
(
3
)(
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
inputs
=
ins
->
inputs
();
argument
starts_arg
=
inputs
.
at
(
1
)
->
eval
();
argument
ends_arg
=
inputs
.
at
(
2
)
->
eval
();
argument
axes_arg
=
inputs
.
at
(
3
)
->
eval
();
if
(
not
starts_arg
.
empty
()
and
not
ends_arg
.
empty
()
and
not
axes_arg
.
empty
())
{
std
::
vector
<
int64_t
>
starts_vec
;
std
::
vector
<
int64_t
>
ends_vec
;
std
::
vector
<
int64_t
>
axes_vec
;
starts_arg
.
visit
([
&
](
auto
output
)
{
starts_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
ends_arg
.
visit
([
&
](
auto
output
)
{
ends_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
axes_arg
.
visit
([
&
](
auto
output
)
{
axes_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
m
.
replace_instruction
(
ins
,
make_op
(
"slice"
,
{{
"starts"
,
starts_vec
},
{
"ends"
,
ends_vec
},
{
"axes"
,
axes_vec
}}),
inputs
.
at
(
0
));
}
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
match
::
find_matches
(
m
,
find_static_2in_broadcasts
{},
find_const_3in_slice
{},
find_const_4in_slice
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/simplify_reshapes.cpp
View file @
8d32c6b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -122,6 +122,11 @@ struct find_nop_reshapes
reshapes
.
insert
(
"pad"
);
reshapes
.
insert
(
"slice"
);
reshapes
.
insert
(
"transpose"
);
reshapes
.
insert
(
"reduce_mean"
);
reshapes
.
insert
(
"reduce_max"
);
reshapes
.
insert
(
"reduce_min"
);
reshapes
.
insert
(
"reduce_sum"
);
reshapes
.
insert
(
"reduce_prod"
);
return
match
::
name
(
reshapes
)(
match
::
same_shape
(
match
::
arg
(
0
)));
}
...
...
@@ -627,6 +632,46 @@ struct find_transpose_contiguous_reshaper_unary
}
};
// simplifies broadcast->transpose to transpose->broadcast
// in the case of a scalar, simply rewrite to broadcast
// this can allow for further optimizations with find_inner_broadcast() in simplify_algebra.cpp
struct
find_broadcast_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
arg
(
0
)(
match
::
name
(
"multibroadcast"
).
bind
(
"bcast_ins"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
transpose
=
r
.
result
;
auto
transpose_lens
=
transpose
->
get_shape
().
lens
();
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
input
=
bcast_ins
->
inputs
().
front
();
// scalar transformation does not need extra transpose
if
(
not
input
->
get_shape
().
scalar
())
{
// find common shape
auto
in_lens
=
input
->
get_shape
().
lens
();
int
lens_diff
=
transpose_lens
.
size
()
-
in_lens
.
size
();
// insert unsqueeze if input lens < transpose lens
if
(
lens_diff
>
0
)
{
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
}
// apply transpose before the multibroadcast
input
=
m
.
insert_instruction
(
bcast_ins
,
transpose
->
get_operator
(),
input
);
}
auto
new_mbcast
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
transpose_lens
}}),
input
);
m
.
replace_instruction
(
transpose
,
new_mbcast
);
}
};
struct
find_slice_transpose
{
auto
matcher
()
const
...
...
@@ -784,7 +829,7 @@ struct find_transpose_slice
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
for
(
int
i
=
0
;
i
<
depth
;
i
++
)
{
match
::
find_matches
(
m
,
find_where_op
{},
...
...
@@ -799,6 +844,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice
{},
find_nested_concat
{},
find_transpose_slice
{},
find_broadcast_transpose
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{});
dead_code_elimination
{}.
apply
(
m
);
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment