gaoqiong / MIGraphX — commit 538dbd75
Authored Dec 05, 2023 by Brian Pickrell

    Merge branch 'develop' into resize_op

Parents: c7161d99, e3e00547
Changes: 182 — showing 20 changed files with 904 additions and 180 deletions (+904 −180)
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp          +1   −1
src/targets/gpu/mlir.cpp                                          +28  −48
src/targets/gpu/prefuse_ops.cpp                                   +33  −16
src/targets/gpu/target.cpp                                        +1   −0
test/gpu/fuse_mlir.cpp                                            +5   −3
test/gpu/jit.cpp                                                  +11  −9
test/gpu/mlir.cpp                                                 +52  −29
test/include/test.hpp                                             +1   −0
test/onnx/.onnxrt-commit                                          +1   −1
test/onnx/averagepool_dilate_test.onnx                            +17  −0
test/onnx/gen_onnx.py                                             +440 −26
test/onnx/maxpool_dilate_test.onnx                                +17  −0
test/onnx/onnx_test.cpp                                           +297 −47
test/onnx/qlinearaveragepool_1d_test.onnx                         +0   −0  (binary)
test/onnx/qlinearaveragepool_2d_ceil_test.onnx                    +0   −0  (binary)
test/onnx/qlinearaveragepool_2d_dilations_test.onnx               +0   −0  (binary)
test/onnx/qlinearaveragepool_2d_pads_count_include_pad_test.onnx  +0   −0  (binary)
test/onnx/qlinearaveragepool_2d_same_lower_test.onnx              +0   −0  (binary)
test/onnx/qlinearaveragepool_2d_same_upper_test.onnx              +0   −0  (binary)
test/onnx/qlinearaveragepool_2d_strides_test.onnx                 +0   −0  (binary)
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp

@@ -207,7 +207,7 @@ struct implicit_conversion_op
     template <class U>
     constexpr operator U() const
     {
-        return x;
+        return static_cast<U>(x);
     }
 };
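The one-line fix above replaces an implicit conversion with an explicit static_cast inside the templated conversion operator, keeping conversions to narrower types deliberate and quieting narrowing diagnostics. A minimal standalone sketch of the pattern (hypothetical names, not the MIGraphX source):

// Standalone sketch: a wrapper whose templated conversion operator casts
// explicitly, so conversions to narrower types are intentional rather than
// relying on an implicit conversion of the stored value.
#include <cstdint>

template <class T>
struct implicit_conversion_wrapper
{
    T x;
    template <class U>
    constexpr operator U() const
    {
        return static_cast<U>(x); // explicit cast documents the narrowing
    }
};

int main()
{
    implicit_conversion_wrapper<double> w{3.75};
    std::int32_t i = w; // invokes operator U() with U = std::int32_t; i == 3
    return i == 3 ? 0 : 1;
}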
src/targets/gpu/mlir.cpp

@@ -37,7 +37,7 @@
 #include <mlir-c/Pass.h>
 #include <mlir-c/Support.h>
 #include <mutex>
-#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
 #warning "Incompatible version of rocMLIR library used, disabling"
 // Only undefine when not using cppcheck
 #ifndef CPPCHECK

@@ -73,6 +73,7 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
@@ -318,31 +319,30 @@ struct mlir_program
         return result;
     }

-    MlirType make_tensor(const shape& s) const
+    MlirType make_mlir_shaped(const shape& s) const
     {
-        if(not s.standard())
-            MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
         if(s.dynamic())
             MIGRAPHX_THROW("MLIR does not support dynamic shapes");
         std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
-        return mlirRankedTensorTypeGet(
-            lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
+        std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
+        return rocmlirMIXRShapedTypeGet(
+            lens.size(), lens.data(), strides.data(), make_type(s.type()));
     }

     template <class Range>
-    std::vector<MlirType> make_tensors(const Range& r)
+    std::vector<MlirType> make_mlir_shapeds(const Range& r)
     {
         std::vector<MlirType> result;
         std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
-            return make_tensor(s);
+            return make_mlir_shaped(s);
         });
         return result;
     }

     MlirType make_function_type(const std::vector<shape>& inputs,
                                 const std::vector<shape>& outputs)
     {
-        auto in  = make_tensors(inputs);
-        auto out = make_tensors(outputs);
+        auto in  = make_mlir_shapeds(inputs);
+        auto out = make_mlir_shapeds(outputs);
         return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
     }
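The renamed make_mlir_shaped encodes both dimensions and strides in the rocMLIR shaped type, which is why the old "standard shape" restriction could be dropped. A small sketch (assumption: illustrative helper, packed row-major layout) of how the stride values that appear later in the tests arise:

// Sketch: the strides a shape like the ones above would carry. For a packed
// row-major (standard) shape, stride[i] is the product of the dimensions to
// its right; a permuted layout simply permutes these values.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::int64_t> packed_strides(const std::vector<std::int64_t>& lens)
{
    std::vector<std::int64_t> strides(lens.size(), 1);
    for(int i = static_cast<int>(lens.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * lens[i + 1];
    return strides;
}

int main()
{
    // NCHW {1, 8, 4, 4} packed -> strides {128, 16, 4, 1}, matching the
    // !migraphx.shaped<1x8x4x4xf32, 128x16x4x1> seen in test/gpu/mlir.cpp.
    for(auto s : packed_strides({1, 8, 4, 4}))
        std::cout << s << ' ';
    std::cout << '\n';
}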
@@ -504,11 +504,7 @@ struct mlir_program
     mlir_operation_state& add_results(const std::vector<shape>& outputs)
     {
-        std::vector<shape> reshaped(outputs.size());
-        std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
-            return shape{r.type(), r.lens()};
-        });
-        auto x = prog->make_tensors(reshaped);
+        auto x = prog->make_mlir_shapeds(outputs);
         if(not x.empty())
         {
             mlirOperationStateAddResults(&op_state, x.size(), x.data());

@@ -581,7 +577,7 @@ struct mlir_program
         std::vector<shape> outputs = m.get_output_shapes();
         std::vector<MlirLocation> arg_locs(inputs.size(), location);
-        auto body_inputs   = make_tensors(inputs);
+        auto body_inputs   = make_mlir_shapeds(inputs);
         mlir_region region = mlirRegionCreate();
         mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
         MlirBlock result = fbody.get();

@@ -607,7 +603,7 @@ struct mlir_program
             return "func.return";
         if(ins->name() == "@literal")
         {
-            return "tosa.const";
+            return "migraphx.literal";
         }
         return "migraphx." + ins->name();
     }

@@ -666,7 +662,8 @@ struct mlir_program
             if(ins->name() == "@literal")
             {
                 literal r = ins->get_literal();
-                MlirType tensor_type = make_tensor(ins->get_shape());
+                MlirType shaped_type = make_mlir_shaped(ins->get_shape());
+                MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
                 MlirAttribute mlir_value_attr =
                     mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
                 ops.add_attributes({{"value", mlir_value_attr}});
@@ -796,7 +793,9 @@ struct mlir_program
         if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
             tuning_mode = RocmlirTuningParamSetKindExhaustive;
         mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
-        for(auto i : range(mlirRockTuningGetNumParams(params.get())))
+        const auto limit =
+            value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
+        for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
         {
             mlir_tuning_param param{mlirRockTuningParamCreate()};
             if(not mlirRockTuningParamGet(params.get(), i, param.get()))
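Note on the new environment variable: MIGRAPHX_MLIR_TUNE_LIMIT is read with value_of and defaults to std::numeric_limits<std::size_t>::max(), so the std::min clamp leaves the loop unchanged unless the variable is set; exporting it caps how many tuning-space candidates the loop visits.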
@@ -942,35 +941,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
         auto param = m.get_parameter(name);
         if(input.standard())
             continue;
-        auto lens    = input.lens();
-        auto strides = input.strides();
-        std::vector<operation> ops;
-        if(input.transposed())
-        {
-            auto perm  = find_permutation(input);
-            auto iperm = invert_permutation(perm);
-            lens       = reorder_dims(lens, iperm);
-            strides    = reorder_dims(strides, iperm);
-            ops.push_back(make_op("transpose", {{"permutation", perm}}));
-        }
-        if(input.broadcasted())
-        {
-            std::transform(lens.begin(),
-                           lens.end(),
-                           strides.begin(),
-                           lens.begin(),
-                           [](auto len, auto stride) -> std::size_t {
-                               if(stride == 0)
-                                   return 1;
-                               return len;
-                           });
-            ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
-        }
-        auto new_param = std::accumulate(
-            ops.begin(),
-            ops.end(),
-            m.add_parameter(name + ".0", shape{input.type(), lens}),
-            [&](auto x, auto op) { return m.insert_instruction(param, op, x); });
+        auto new_param = m.add_parameter(name + ".0", input);
         m.replace_instruction(param, new_param);
         m.remove_instruction(param);
     }
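With make_mlir_shaped carrying strides, adjust_param_shapes no longer has to rebuild transpose/multibroadcast chains to normalize non-standard inputs; it can register the strided shape directly as the new parameter's type and let rocMLIR consume the layout as-is.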
@@ -1032,6 +1003,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
     mlir_program mp;
     mp.set_gpu_properties(migraphx_ctx);
     mp.parse(m);
+    const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
+    static std::mutex mutex;
+    if(trace)
+    {
+        const std::lock_guard<std::mutex> lock(mutex);
+        auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
+        std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
+    }
     return mp.get_tuning_config(exhaustive);
 }
src/targets/gpu/prefuse_ops.cpp

@@ -31,6 +31,7 @@
 #ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
 #include <migraphx/gpu/ck.hpp>
 #endif
+#include <migraphx/gpu/fuse_mlir.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -124,34 +125,55 @@ struct find_add_layernorm
     }
 };

-#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
 struct pre_gemm_softmax_gemm : gemm_softmax_gemm
 {
     std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
 };
 MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);

-MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
+auto is_ck_gemm()
 {
-    if(ins->name() != "dot")
-        return false;
-    if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
-        return false;
-    return true;
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
+        if(not enabled(MIGRAPHX_ENABLE_CK{}))
+            return false;
+        if(ins->name() != "dot")
+            return false;
+        if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
+            return false;
+        return true;
+#else
+        (void)ins;
+        return false;
+#endif
+    });
+}
+
+auto is_mlir_gemm()
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(not mlir_attention_enabled())
+            return false;
+        if(ins->name() != "dot")
+            return false;
+        return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
+            return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
+        });
+    });
 }

 struct find_gemm_softmax_gemm
 {
     auto matcher() const
     {
-        auto gemm1 =
-            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+        auto gemm1 = match::skip(match::name("contiguous"))(
+            match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
         auto mul = match::name("mul")(
             match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
         auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
+        return match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(
+            match::arg(0)(softmax));
     }

     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
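Moving the MIGRAPHX_USE_COMPOSABLEKERNEL guard and the MIGRAPHX_ENABLE_CK check inside the predicate lets the pattern be registered unconditionally: match::any_of(is_ck_gemm(), is_mlir_gemm()) simply evaluates to false for a backend that is compiled out or disabled, so one matcher now serves both the CK and MLIR attention paths.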
@@ -179,8 +201,6 @@ struct find_gemm_softmax_gemm
     }
 };

-#endif
-
 } // namespace

 void prefuse_ops::apply(module_pass_manager& mpm) const

@@ -188,10 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
     match::find_matches(mpm.get_module(), find_layernorm{});
     mpm.run_pass(dead_code_elimination{});
     match::find_matches(mpm.get_module(), find_add_layernorm{});
-#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL
-    if(enabled(MIGRAPHX_ENABLE_CK{}))
-        match::find_matches(mpm, find_gemm_softmax_gemm{});
-#endif
+    match::find_matches(mpm, find_gemm_softmax_gemm{});
 }

 } // namespace gpu
src/targets/gpu/target.cpp

@@ -98,6 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
     std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
     unsupported_types.erase(shape::type_t::float_type);
+    unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
     unsupported_types.erase(shape::type_t::half_type);
     unsupported_types.erase(shape::type_t::bool_type);
     unsupported_types.erase(shape::type_t::int8_type);
test/gpu/fuse_mlir.cpp

@@ -144,10 +144,12 @@ TEST_CASE(int_quant_dot_tanh_fails)
         auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
         mm->add_return({tanh});
     }
-    migraphx::program p2(p1);
-    // This pass should do nothing as int32_t tanh isn't supported.
+    // This pass should not fuse as int32_t tanh isn't supported.
     run_pass(p1);
-    EXPECT(p1 == p2);
+    auto* mm           = p1.get_main_module();
+    bool has_pointwise = std::any_of(
+        mm->begin(), mm->end(), [&](const auto& i) { return i.name() == "pointwise"; });
+    EXPECT(has_pointwise);
 }

 int main(int argc, const char* argv[])
test/gpu/jit.cpp

@@ -139,7 +139,8 @@ const std::string math_template = R"__migraphx__(
 #include <migraphx/kernels/pointwise.hpp>
 #include <migraphx/kernels/math.hpp>
 #include <migraphx/kernels/types.hpp>
-using namespace migraphx;
+namespace migraphx {
 extern "C" {
 __global__ void kernel(${type}* p)
 {

@@ -148,6 +149,7 @@ __global__ void kernel(${type}* p)
 }
 }
+}

 int main() {}

@@ -348,18 +350,19 @@ TEST_CASE(compile_math)
     auto vec_sizes = {2, 4, 6};
     for(auto&& t : migraphx::shape::types())
     {
-        if(contains({migraphx::shape::bool_type,
-                     migraphx::shape::fp8e4m3fnuz_type,
-                     migraphx::shape::tuple_type},
-                    t))
+        if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
             continue;
         auto name = migraphx::shape::cpp_type(t);
         if(t == migraphx::shape::half_type)
             name.insert(0, "migraphx::");
         data_types.push_back(name);
-        migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
-            return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
-        });
+        // fp8 doesn't have vectorization support yet, therefore skip it for now.
+        if(t != migraphx::shape::fp8e4m3fnuz_type)
+        {
+            migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) {
+                return "migraphx::vec<" + name + ", " + std::to_string(i) + ">";
+            });
+        }
     }
     migraphx::shape input{migraphx::shape::float_type, {5, 2}};
     migraphx::gpu::hip_compile_options options;

@@ -429,7 +432,6 @@ TEST_CASE(assert_type_min_max)
             min = std::to_string(as.min());
             max = std::to_string(as.max());
         }
-
         auto src = migraphx::interpolate_string(assert_template,
                                                 {{"type", name}, {"max", max}, {"min", min}});
         migraphx::shape input{migraphx::shape::float_type, {5, 2}};
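A note on the math_template change: it swaps the `using namespace migraphx;` directive for an explicit `namespace migraphx { ... }` wrapper around the test kernel (the matching closing brace is the `+}` in the next hunk), presumably so the kernel body resolves names inside the migraphx namespace instead of importing it wholesale into the global scope.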
test/gpu/mlir.cpp

@@ -141,9 +141,9 @@ TEST_CASE(conv)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
-    return %0 : tensor<1x2x2x2xf32>
+  func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
   }
 }
 )__migraphx__";
@@ -160,15 +160,38 @@ module {
     EXPECT(verify_mlir(m));
 }

+TEST_CASE(conv_nhwc)
+{
+    const std::string mlir_output = R"__migraphx__(
+module {
+  func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
+    return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto x    = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
+    auto w    = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
+    auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
+    m.add_return({conv});
+    auto s = migraphx::gpu::dump_mlir(m);
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+    CHECK(encode(s) == encode(mlir_output));
+    EXPECT(verify_mlir(m));
+}
+
 TEST_CASE(conv_add_relu)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
-    %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
-    %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
-    return %2 : tensor<1x2x2x2xf32>
+  func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
   }
 }
 )__migraphx__";
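The new conv_nhwc test relies on the shaped type carrying strides. A sketch (illustrative computation, not from the test) of how its NHWC strides arise: the logical shape stays NCHW {1, 8, 4, 4}, but the data is laid out channels-last, so the packed strides of the NHWC ordering are permuted back to NCHW positions.

// Sketch: derive the {128, 1, 32, 8} strides used in conv_nhwc above.
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // Packed strides for the NHWC layout {N=1, H=4, W=4, C=8}: {128, 32, 8, 1}.
    std::vector<std::int64_t> nhwc_lens{1, 4, 4, 8};
    std::vector<std::int64_t> strides(4, 1);
    for(int i = 2; i >= 0; --i)
        strides[i] = strides[i + 1] * nhwc_lens[i + 1];
    // Reorder to NCHW positions {N, C, H, W} -> {128, 1, 32, 8}, matching
    // !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>.
    std::vector<std::int64_t> nchw_strides{strides[0], strides[3], strides[1], strides[2]};
    for(auto s : nchw_strides)
        std::cout << s << ' '; // prints: 128 1 32 8
    std::cout << '\n';
}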
@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
-    %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
-    return %1 : tensor<1x5x3xi32>
+  func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
   }
 }
 )__migraphx__";

@@ -219,10 +242,10 @@ TEST_CASE(dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
-    return %1 : tensor<1x5x3xf32>
+  func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
   }
 }
 )__migraphx__";

@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
-    %1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
-    %2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
-    return %2 : tensor<1x2x2x2xi32>
+  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
+    %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
+    return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
   }
 }
 )__migraphx__";

@@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
-    return %1 : tensor<1x5x3xf16>
+  func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
   }
 }
 )__migraphx__";

@@ -304,10 +327,10 @@ TEST_CASE(dot_where)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
-    return %1 : tensor<1x5x3xf32>
+  func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
   }
 }
 )__migraphx__";
test/include/test.hpp

@@ -24,6 +24,7 @@
 #include <atomic>
 #include <algorithm>
 #include <array>
+#include <cassert>
 #include <cstdio>
 #include <cstdlib>
test/onnx/.onnxrt-commit

@@ -1 +1 @@
-a5537f2f563d4975c7e6121a7eb260bbbfd9455a
+d69842226b47e5336568103541b071447caeb9bf
test/onnx/averagepool_dilate_test.onnx (new file, mode 100644)
Binary ONNX model, not rendered: a single AveragePool node ("x" -> "y") with dilations, kernel_shape, pads, and strides attributes.
test/onnx/gen_onnx.py

@@ -276,6 +276,22 @@ def averagepool_1d_test():
     return ([node], [x], [out])


+@onnx_test()
+def averagepool_dilate_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 2])
+
+    node = onnx.helper.make_node('AveragePool',
+                                 inputs=['x'],
+                                 outputs=['y'],
+                                 kernel_shape=[2],
+                                 strides=[1],
+                                 pads=[1, 1],
+                                 dilations=[3])
+
+    return ([node], [x], [y])
+
+
 @onnx_test()
 def averagepool_3d_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5, 5, 5])
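The [1, 4, 2] output shape used here (and in the matching MaxPool test below) follows from the standard pooling output-size formula once dilation is accounted for; a worked instance with input length 3, pads 1+1, kernel 2, dilation 3, and stride 1:

L_out = floor((L_in + p_0 + p_1 - d(k - 1) - 1) / s) + 1
      = \left\lfloor \frac{3 + 1 + 1 - 3\,(2 - 1) - 1}{1} \right\rfloor + 1 = 2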
@@ -4882,6 +4898,22 @@ def maxpool_notset_test():
     return ([node], [x], [y])


+@onnx_test()
+def maxpool_dilate_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 3])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 2])
+
+    node = onnx.helper.make_node('MaxPool',
+                                 inputs=['x'],
+                                 outputs=['y'],
+                                 kernel_shape=[2],
+                                 strides=[1],
+                                 pads=[1, 1],
+                                 dilations=[3])
+
+    return ([node], [x], [y])
+
+
 @onnx_test()
 def maxpool_same_upper_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5])
@@ -5962,6 +5994,263 @@ def qlinearadd_bcast_test():
                 [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


+@onnx_test()
+def qlinearaveragepool_1d_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 32])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 31])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.05])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [16])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[2],
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 4, 4])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 3, 3])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.015])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [16])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[2, 2],
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_ceil_test():
+    x = helper.make_tensor_value_info('x', TensorProto.UINT8, [1, 1, 4, 4])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.UINT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.UINT8, [1, 1, 2, 2])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.05])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.UINT8, [], [0])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[3, 3],
+        strides=[2, 2],
+        ceil_mode=True,
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_dilations_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 1, 4, 4])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 1, 2, 2])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.25])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [84])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[2, 2],
+        strides=[1, 1],
+        dilations=[2, 2],
+        ceil_mode=True,
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_pads_count_include_pad_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 4, 4])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 6, 6])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.01])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [32])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[3, 3],
+        pads=[2, 2, 2, 2],
+        count_include_pad=1,
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_same_lower_test():
+    x = helper.make_tensor_value_info('x', TensorProto.UINT8, [1, 3, 4, 4])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.UINT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.UINT8, [1, 3, 4, 4])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.5])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.UINT8, [], [0])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[2, 2],
+        auto_pad="SAME_LOWER",
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_same_upper_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 4, 4])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [32])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 4, 4])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.25])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [0])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[2, 2],
+        auto_pad="SAME_UPPER",
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_2d_strides_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 8, 8])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 2, 2])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.05])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [8])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[5, 5],
+        strides=[2, 2],
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_3d_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 3, 3, 3, 3])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.05])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 3, 2, 2, 2])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.02])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [0])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[2, 2, 2],
+    )
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_notset_test():
+    x = helper.make_tensor_value_info('x', TensorProto.INT8, [1, 1, 5, 5])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.INT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.INT8, [1, 1, 1, 1])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.5])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.INT8, [], [10])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[6, 6],
+        strides=[2, 2],
+        pads=[0, 0, 1, 1],
+        channels_last=0,
+        auto_pad='NOTSET')
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
+@onnx_test()
+def qlinearaveragepool_nt_cip_test():
+    x = helper.make_tensor_value_info('x', TensorProto.UINT8, [1, 1, 5, 5])
+    x_scale = helper.make_tensor('x_scale', TensorProto.FLOAT, [], [0.5])
+    x_zero_point = helper.make_tensor('x_zero_point', TensorProto.UINT8, [], [0])
+
+    y = helper.make_tensor_value_info('y', TensorProto.UINT8, [1, 1, 1, 1])
+    y_scale = helper.make_tensor('y_scale', TensorProto.FLOAT, [], [0.5])
+    y_zero_point = helper.make_tensor('y_zero_point', TensorProto.UINT8, [], [10])
+
+    node = onnx.helper.make_node(
+        'QLinearAveragePool',
+        inputs=['x', 'x_scale', 'x_zero_point', 'y_scale', 'y_zero_point'],
+        outputs=['y'],
+        kernel_shape=[6, 6],
+        strides=[2, 2],
+        pads=[0, 0, 1, 1],
+        channels_last=0,
+        auto_pad='NOTSET',
+        count_include_pad=1)
+
+    return ([node], [x], [y], [x_scale, x_zero_point, y_scale, y_zero_point])
+
+
 @onnx_test()
 def qlinearconv_test():
     # https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
@@ -6094,6 +6383,26 @@ def qlinearglobalavgpool_test():
     return ([n], [x], [y], [sc_x, z_pt_x, sc_y, z_pt_y])


+@onnx_test()
+def qlinearleakyrelu_test():
+    x = helper.make_tensor_value_info('X', TensorProto.INT8, [64])
+    sc_x = helper.make_tensor('X_scale', TensorProto.FLOAT, [], [0.05])
+    zero_pt_x = helper.make_tensor('X_zero_point', TensorProto.INT8, [], [0])
+
+    sc_y = helper.make_tensor('Y_scale', TensorProto.FLOAT, [], [0.05])
+    zero_pt_y = helper.make_tensor('Y_zero_point', TensorProto.INT8, [], [10])
+
+    y = helper.make_tensor_value_info('Y', TensorProto.INT8, [64])
+
+    node = onnx.helper.make_node(
+        'QLinearLeakyRelu',
+        inputs=['X', 'X_scale', 'X_zero_point', 'Y_scale', 'Y_zero_point'],
+        outputs=['Y'],
+        alpha=1.1,
+    )
+
+    return ([node], [x], [y], [sc_x, zero_pt_x, sc_y, zero_pt_y])
+
+
 def qlinearmatmul_1D_test():
     a = helper.make_tensor_value_info('A', TensorProto.UINT8, [8])
     sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.05])

@@ -6234,6 +6543,26 @@ def qlinearmul_bcast_test():
                 [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


+@onnx_test()
+def qlinearsigmoid_test():
+    x = helper.make_tensor_value_info('X', TensorProto.INT8, [64])
+    sc_x = helper.make_tensor('X_scale', TensorProto.FLOAT, [], [0.05])
+    zero_pt_x = helper.make_tensor('X_zero_point', TensorProto.INT8, [], [0])
+
+    sc_y = helper.make_tensor('Y_scale', TensorProto.FLOAT, [], [0.0035])
+    zero_pt_y = helper.make_tensor('Y_zero_point', TensorProto.INT8, [], [-128])
+
+    y = helper.make_tensor_value_info('Y', TensorProto.INT8, [64])
+
+    node = onnx.helper.make_node(
+        'QLinearSigmoid',
+        inputs=['X', 'X_scale', 'X_zero_point', 'Y_scale', 'Y_zero_point'],
+        outputs=['Y'],
+    )
+
+    return ([node], [x], [y], [sc_x, zero_pt_x, sc_y, zero_pt_y])
+
+
 @onnx_test()
 def quantizelinear_test():
     arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
@@ -7383,8 +7712,7 @@ def scatter_none_test():
     return ([node], [x, i, u], [y])


-@onnx_test()
-def scatternd_add_test():
+def make_scatternd_test(reduction="none"):
     data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
     indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1, 2])

@@ -7396,44 +7724,39 @@ def scatternd_add_test():
     node = onnx.helper.make_node('ScatterND',
                                  inputs=['data', 'indices', 'updates'],
                                  outputs=['output'],
-                                 reduction="add")
+                                 reduction=reduction)

     return ([node], [data, indices, updates], [output])


 @onnx_test()
+def scatternd_add_test():
+    return make_scatternd_test("add")
+
+
+@onnx_test()
 def scatternd_mul_test():
-    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
-    indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1, 2])
-    updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, [2, 1, 2])
-    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2, 2])
-
-    node = onnx.helper.make_node('ScatterND',
-                                 inputs=['data', 'indices', 'updates'],
-                                 outputs=['output'],
-                                 reduction="mul")
-
-    return ([node], [data, indices, updates], [output])
+    return make_scatternd_test("mul")
+
+
+@onnx_test()
+def scatternd_max_test():
+    return make_scatternd_test("max")
+
+
+@onnx_test()
+def scatternd_min_test():
+    return make_scatternd_test("min")


 @onnx_test()
 def scatternd_test():
-    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
-    indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1, 2])
-    updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, [2, 1, 2])
-    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2, 2])
-
-    node = onnx.helper.make_node('ScatterND',
-                                 inputs=['data', 'indices', 'updates'],
-                                 outputs=['output'])
-
-    return ([node], [data, indices, updates], [output])
+    return make_scatternd_test()
+
+
+@onnx_test()
+def scatternd_invalid_reduction_test():
+    return make_scatternd_test("invalid")


 @onnx_test()
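Folding the near-identical ScatterND generators into make_scatternd_test(reduction) keeps one source of truth for the graph; the invalid-reduction variant reuses it to emit a model that the parser is expected to reject, as checked in onnx_test.cpp further below.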
@@ -9220,6 +9543,97 @@ def undefined_test():
     return ([node], [x], [y])


+@onnx_test()
+def unique_dynamic_sorted_test():
+    x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
+    y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
+    y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
+    x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64, [6])
+    count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
+
+    node = onnx.helper.make_node(
+        'Unique',
+        inputs=['X'],
+        outputs=['Y', 'indices', 'inverse_indices', 'counts'],
+        axis=0,
+        sorted=1)
+
+    return ([node], [x], [y, y_ind, x_ind, count])
+
+
+@onnx_test()
+def unique_dynamic_sorted_3D_test():
+    x = helper.make_tensor_value_info('X', TensorProto.INT64, [4, 4, 4])
+    y = helper.make_tensor_value_info('Y', TensorProto.INT64, [16])
+    y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [16])
+    x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64, [64])
+    count = helper.make_tensor_value_info('counts', TensorProto.INT64, [16])
+
+    node = onnx.helper.make_node(
+        'Unique',
+        inputs=['X'],
+        outputs=['Y', 'indices', 'inverse_indices', 'counts'],
+        sorted=1)
+
+    return ([node], [x], [y, y_ind, x_ind, count])
+
+
+@onnx_test()
+def unique_dynamic_unsorted_test():
+    x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [6])
+    y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
+    y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
+    x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64, [6])
+    count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
+
+    node = onnx.helper.make_node(
+        'Unique',
+        inputs=['X'],
+        outputs=['Y', 'indices', 'inverse_indices', 'counts'],
+        axis=0,
+        sorted=0)
+
+    return ([node], [x], [y, y_ind, x_ind, count])
+
+
+@onnx_test()
+def unique_sorted_test():
+    x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
+    y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
+    y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
+    x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64, [6])
+    count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
+
+    node = onnx.helper.make_node(
+        'Unique',
+        inputs=['X'],
+        outputs=['Y', 'indices', 'inverse_indices', 'counts'],
+        axis=0,
+        sorted=1)
+
+    return ([node], [], [y, y_ind, x_ind, count], [x])
+
+
+@onnx_test()
+def unique_unsorted_test():
+    x = helper.make_tensor('X', TensorProto.FLOAT, [6], [2, 1, 1, 3, 4, 3])
+    y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4])
+    y_ind = helper.make_tensor_value_info('indices', TensorProto.INT64, [4])
+    x_ind = helper.make_tensor_value_info('inverse_indices', TensorProto.INT64, [6])
+    count = helper.make_tensor_value_info('counts', TensorProto.INT64, [4])
+
+    node = onnx.helper.make_node(
+        'Unique',
+        inputs=['X'],
+        outputs=['Y', 'indices', 'inverse_indices', 'counts'],
+        axis=0,
+        sorted=0)
+
+    return ([node], [], [y, y_ind, x_ind, count], [x])
+
+
 @onnx_test()
 def unknown_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
test/onnx/maxpool_dilate_test.onnx (new file, mode 100644)
Binary ONNX model, not rendered: a single MaxPool node ("x" -> "y") with dilations, kernel_shape, pads, and strides attributes.
test/onnx/onnx_test.cpp

@@ -296,13 +296,32 @@ TEST_CASE(averagepool_1d_test)
                            {{"mode", migraphx::op::pooling_mode::average},
                             {"padding", {0, 0}},
                             {"stride", {1}},
-                            {"lengths", {3}}}),
+                            {"lengths", {3}},
+                            {"dilations", {1}}}),
         l0);
     auto prog = optimize_onnx("averagepool_1d_test.onnx");
     EXPECT(p == prog);
 }

+TEST_CASE(averagepool_dilate_test)
+{
+    migraphx::program p;
+    auto* mm   = p.get_main_module();
+    auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 3}});
+    mm->add_instruction(migraphx::make_op("pooling",
+                                          {{"mode", migraphx::op::pooling_mode::average},
+                                           {"padding", {1, 1}},
+                                           {"stride", {1}},
+                                           {"lengths", {2}},
+                                           {"dilations", {3}}}),
+                        input);
+    auto prog = optimize_onnx("averagepool_dilate_test.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(averagepool_3d_test)
 {
     migraphx::program p;
@@ -312,7 +331,8 @@ TEST_CASE(averagepool_3d_test)
                            {{"mode", migraphx::op::pooling_mode::average},
                             {"padding", {0, 0, 0, 0, 0, 0}},
                             {"stride", {1, 1, 1}},
-                            {"lengths", {3, 3, 3}}}),
+                            {"lengths", {3, 3, 3}},
+                            {"dilations", {1, 1, 1}}}),
         l0);
     auto prog = optimize_onnx("averagepool_3d_test.onnx");

@@ -332,6 +352,7 @@ TEST_CASE(averagepool_dyn_test)
                            {"mode", migraphx::op::pooling_mode::average},
                            {"stride", {2, 2, 2}},
                            {"lengths", {3, 3, 3}},
+                           {"dilations", {1, 1, 1}},
                            {"padding", {1, 1, 1, 1, 1, 1}},
                            {"padding_mode", 0},
                        }),

@@ -357,6 +378,7 @@ TEST_CASE(averagepool_dyn_autopad_test)
                            {"mode", migraphx::op::pooling_mode::average},
                            {"stride", {2, 2, 2}},
                            {"lengths", {3, 3, 3}},
+                           {"dilations", {1, 1, 1}},
                            {"padding", {0, 0, 0, 0, 0, 0}},
                            {"padding_mode", migraphx::op::padding_mode_t::same_upper},
                        }),

@@ -394,7 +416,8 @@ TEST_CASE(averagepool_notset_test)
                            {{"mode", migraphx::op::pooling_mode::average},
                             {"padding", {2, 2, 2, 2}},
                             {"stride", {2, 2}},
-                            {"lengths", {6, 6}}}),
+                            {"lengths", {6, 6}},
+                            {"dilations", {1, 1}}}),
         input);
     auto ret = mm->add_instruction(
         migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}),
         ins);

@@ -415,7 +438,8 @@ TEST_CASE(averagepool_nt_cip_test)
                            {{"mode", migraphx::op::pooling_mode::average},
                             {"padding", {0, 0, 0, 0}},
                             {"stride", {2, 2}},
-                            {"lengths", {6, 6}}}),
+                            {"lengths", {6, 6}},
+                            {"dilations", {1, 1}}}),
         ins_pad);
     mm->add_return({ret});

@@ -437,6 +461,7 @@ TEST_CASE(averagepool_same_lower_test)
                            {"padding", {1, 1, 1, 1}},
                            {"stride", {1, 1}},
                            {"lengths", {2, 2}},
+                           {"dilations", {1, 1}},
                            {"padding_mode", migraphx::op::padding_mode_t::default_},
                        }),
         input);

@@ -459,7 +484,8 @@ TEST_CASE(averagepool_sl_cip_test)
                            {{"mode", migraphx::op::pooling_mode::average},
                             {"padding", {0, 0, 0, 0}},
                             {"stride", {1, 1}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         ins_pad);
     mm->add_return({ret});
     auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");

@@ -476,7 +502,8 @@ TEST_CASE(averagepool_same_upper_test)
                            {{"mode", migraphx::op::pooling_mode::average},
                             {"padding", {1, 1, 1, 1}},
                             {"stride", {1, 1}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         input);
     auto ret = mm->add_instruction(
         migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}),
         ins);
@@ -1307,7 +1334,8 @@ TEST_CASE(conv_bn_relu_maxpool_test)
                            {{"mode", migraphx::op::pooling_mode::max},
                             {"padding", {0, 0, 0, 0}},
                             {"stride", {2, 2}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         l7);
     auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");

@@ -1505,7 +1533,8 @@ TEST_CASE(conv_relu_maxpool_test)
                            {{"mode", migraphx::op::pooling_mode::max},
                             {"padding", {0, 0, 0, 0}},
                             {"stride", {2, 2}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         l6);
     auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");

@@ -1530,7 +1559,8 @@ TEST_CASE(conv_relu_maxpool_x2_test)
                            {{"mode", migraphx::op::pooling_mode::max},
                             {"padding", {0, 0, 0, 0}},
                             {"stride", {2, 2}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         l6);
     auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});

@@ -1546,7 +1576,8 @@ TEST_CASE(conv_relu_maxpool_x2_test)
                            {{"mode", migraphx::op::pooling_mode::max},
                             {"padding", {0, 0, 0, 0}},
                             {"stride", {2, 2}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         l13);
     auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");

@@ -4245,6 +4276,7 @@ TEST_CASE(lppool_l1_test)
                            {"padding", {0, 0}},
                            {"stride", {1}},
                            {"lengths", {3}},
+                           {"dilations", {1}},
                            {"lp_order", 1}}),
         l0);
     auto prog = optimize_onnx("lppool_l1_test.onnx");

@@ -4261,6 +4293,7 @@ TEST_CASE(lppool_l2_test)
                            {"padding", {0, 0}},
                            {"stride", {1}},
                            {"lengths", {3}},
+                           {"dilations", {1}},
                            {"lp_order", 2}}),
         l0);
     auto prog = optimize_onnx("lppool_l2_test.onnx");

@@ -4513,7 +4546,8 @@ TEST_CASE(maxpool_notset_test)
                            {{"mode", migraphx::op::pooling_mode::max},
                             {"padding", {0, 0, 1, 1}},
                             {"stride", {2, 2}},
-                            {"lengths", {6, 6}}}),
+                            {"lengths", {6, 6}},
+                            {"dilations", {1, 1}}}),
         input);
     auto prog = optimize_onnx("maxpool_notset_test.onnx");

@@ -4521,6 +4555,24 @@ TEST_CASE(maxpool_notset_test)
     EXPECT(p == prog);
 }

+TEST_CASE(maxpool_dilate_test)
+{
+    migraphx::program p;
+    auto* mm   = p.get_main_module();
+    auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 3}});
+    mm->add_instruction(migraphx::make_op("pooling",
+                                          {{"mode", migraphx::op::pooling_mode::max},
+                                           {"padding", {1, 1}},
+                                           {"stride", {1}},
+                                           {"lengths", {2}},
+                                           {"dilations", {3}}}),
+                        input);
+    auto prog = optimize_onnx("maxpool_dilate_test.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(maxpool_same_upper_test)
 {
     migraphx::program p;

@@ -4530,7 +4582,8 @@ TEST_CASE(maxpool_same_upper_test)
                            {{"mode", migraphx::op::pooling_mode::max},
                             {"padding", {0, 0, 1, 1}},
                             {"stride", {1, 1}},
-                            {"lengths", {2, 2}}}),
+                            {"lengths", {2, 2}},
+                            {"dilations", {1, 1}}}),
         input);
     auto prog = optimize_onnx("maxpool_same_upper_test.onnx");
@@ -4773,8 +4826,9 @@ TEST_CASE(multinomial_test)
     migraphx::shape s{migraphx::shape::float_type, {1}};
     std::vector<float> seed_data = {seed};
     auto seed_input              = mm->add_literal(migraphx::literal(s, seed_data));
-    auto rand_dummy =
-        mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
+    auto rand_dummy = mm->add_literal(
+        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
+                          std::vector<float>(batch_size * sample_size)});
     auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
     mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);

@@ -4925,8 +4979,9 @@ TEST_CASE(multinomial_int64_test)
     auto seed_input = mm->add_literal(migraphx::literal(s, data));
     // static size
-    auto rand_dummy =
-        mm->add_literal(migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
+    auto rand_dummy = mm->add_literal(
+        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
+                          std::vector<float>(batch_size * sample_size)});
     auto randoms = mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
     mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", dtype}}), cdf, randoms);
     auto prog = optimize_onnx("multinomial_int64_test.onnx");
@@ -5542,6 +5597,54 @@ TEST_CASE(qlinearadd_test)
     EXPECT(p.sort() == prog.sort());
 }

+TEST_CASE(qlinearaveragepool_notset_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+
+    auto sc_x   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
+    auto z_pt_x = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {0}});
+    auto sc_y   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
+    auto z_pt_y = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {10}});
+
+    auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {1, 1, 5, 5}});
+
+    auto scale_x_bcast = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 5, 5}}}), sc_x);
+    auto z_pt_x_bcast = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 5, 5}}}), z_pt_x);
+    auto fp_x =
+        mm->add_instruction(migraphx::make_op("dequantizelinear"), x, scale_x_bcast, z_pt_x_bcast);
+
+    auto fp_y =
+        mm->add_instruction(migraphx::make_op("pooling",
+                                              {{"mode", migraphx::op::pooling_mode::average},
+                                               {"padding", {2, 2, 2, 2}},
+                                               {"stride", {2, 2}},
+                                               {"lengths", {6, 6}}}),
+                            fp_x);
+    fp_y = mm->add_instruction(
+        migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), fp_y);
+
+    auto scale_y_bcast = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1}}}), sc_y);
+    auto z_pt_y_bcast = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1}}}), z_pt_y);
+    auto y =
+        mm->add_instruction(migraphx::make_op("quantizelinear"), fp_y, scale_y_bcast, z_pt_y_bcast);
+
+    mm->add_return({y});
+
+    auto prog = migraphx::parse_onnx("qlinearaveragepool_notset_test.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(qlinearconv_test)
 {
     migraphx::program p;
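The expected program above spells out how the parser lowers QLinearAveragePool: dequantize the int8 input, run an ordinary average pooling in float, then requantize with the output scale and zero point. A scalar sketch of that arithmetic for one window (illustrative values, not taken from the test):

// Sketch: QLinearAveragePool as dequantize -> average -> quantize, one window.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const float x_scale = 0.5f, y_scale = 0.5f;
    const std::int8_t x_zero = 0, y_zero = 10;
    std::vector<std::int8_t> window{4, 6, 8, 2};

    float sum = 0.0f;
    for(auto q : window)
        sum += (q - x_zero) * x_scale; // dequantizelinear
    const float avg = sum / window.size(); // pooling (average): 2.5

    // quantizelinear: round, shift by the output zero point, saturate to int8
    const int q = static_cast<int>(std::nearbyint(avg / y_scale)) + y_zero;
    const auto y = static_cast<std::int8_t>(std::clamp(q, -128, 127));
    std::cout << static_cast<int>(y) << '\n'; // 2.5 / 0.5 = 5, +10 -> prints 15
}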
@@ -5642,6 +5745,46 @@ TEST_CASE(qlinearglobalavgpool_test)
EXPECT
(
p
.
sort
()
==
prog
.
sort
());
}
TEST_CASE
(
qlinearleakyrelu_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"X"
,
{
migraphx
::
shape
::
int8_type
,
{
64
}});
auto
sc_x
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
0.05
}});
auto
z_pt_x
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
int8_type
,
{
0
}});
auto
sc_y
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
0.05
}});
auto
z_pt_y
=
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
int8_type
,
{
10
}});
auto
scale_x_bcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
64
}}}),
sc_x
);
auto
z_pt_x_bcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
64
}}}),
z_pt_x
);
auto
fp_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
x
,
scale_x_bcast
,
z_pt_x_bcast
);
auto
fp_y
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"leaky_relu"
,
{{
"alpha"
,
1.1
}}),
fp_x
);
auto
scale_y_bcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
64
}}}),
sc_y
);
auto
z_pt_y_bcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
64
}}}),
z_pt_y
);
auto
y
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
fp_y
,
scale_y_bcast
,
z_pt_y_bcast
);
mm
->
add_return
({
y
});
auto
prog
=
migraphx
::
parse_onnx
(
"qlinearleakyrelu_test.onnx"
);
EXPECT
(
p
.
sort
()
==
prog
.
sort
());
}
TEST_CASE
(
qlinearmatmul_1D_test
)
{
migraphx
::
program
p
;
...
...
@@ -5807,6 +5950,46 @@ TEST_CASE(qlinearmul_test)
EXPECT
(
p
.
sort
()
==
prog
.
sort
());
}

TEST_CASE(qlinearsigmoid_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto x = mm->add_parameter("X", {migraphx::shape::int8_type, {64}});

    auto sc_x   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto z_pt_x = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {0}});

    auto sc_y   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.0035}});
    auto z_pt_y = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {-128}});

    auto scale_x_bcast =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_x);
    auto z_pt_x_bcast =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_x);
    auto fp_x =
        mm->add_instruction(migraphx::make_op("dequantizelinear"), x, scale_x_bcast, z_pt_x_bcast);

    auto fp_y = mm->add_instruction(migraphx::make_op("sigmoid"), fp_x);

    auto scale_y_bcast =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_y);
    auto z_pt_y_bcast =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_y);
    auto y =
        mm->add_instruction(migraphx::make_op("quantizelinear"), fp_y, scale_y_bcast, z_pt_y_bcast);
    mm->add_return({y});

    auto prog = migraphx::parse_onnx("qlinearsigmoid_test.onnx");
    EXPECT(p.sort() == prog.sort());
}
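
// Note: the output quantization parameters above (scale 0.0035, zero point -128)
// map the int8 range [-128, 127] onto roughly [0.0, 0.89], since
// 0.0035 * (127 - (-128)) = 0.8925, which covers most of sigmoid's (0, 1) range.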

migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m,
                                                     const migraphx::instruction_ref ins,
                                                     const migraphx::instruction_ref round,
...
...
@@ -7094,20 +7277,35 @@ TEST_CASE(scatter_none_test)
    EXPECT(p == prog);
}

TEST_CASE(scatternd_test)
void scatternd_test_base(const std::string& reduction, const std::string& onnx_file)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
    auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
    auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
    auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
    auto r = mm->add_instruction(migraphx::make_op("scatternd_" + reduction), l0, l1, l2);
    mm->add_return({r});
    auto prog = migraphx::parse_onnx("scatternd_test.onnx");
    auto prog = migraphx::parse_onnx(onnx_file);
    EXPECT(p == prog);
}
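
// An illustrative sketch (not the MIGraphX implementation) of the ScatterND
// arithmetic these tests exercise, specialized to the shapes used above:
// data {2, 2, 2}, indices {2, 1, 2}, updates {2, 1, 2}. Each index tuple
// indices[i][0] selects an innermost length-2 slice of data; "add" accumulates
// the matching updates slice into it, while "none" assigns it:
//
//     #include <cstdint>
//
//     void scatternd_add_sketch(float data[2][2][2],
//                               const std::int64_t indices[2][1][2],
//                               const float updates[2][1][2])
//     {
//         for(int i = 0; i < 2; ++i)
//         {
//             const auto i0 = indices[i][0][0];
//             const auto i1 = indices[i][0][1];
//             for(int k = 0; k < 2; ++k)
//                 data[i0][i1][k] += updates[i][0][k];
//         }
//     }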

TEST_CASE(scatternd_test) { scatternd_test_base("none", "scatternd_test.onnx"); }

TEST_CASE(scatternd_add_test) { scatternd_test_base("add", "scatternd_add_test.onnx"); }

TEST_CASE(scatternd_mul_test) { scatternd_test_base("mul", "scatternd_mul_test.onnx"); }

TEST_CASE(scatternd_max_test) { scatternd_test_base("max", "scatternd_max_test.onnx"); }

TEST_CASE(scatternd_min_test) { scatternd_test_base("min", "scatternd_min_test.onnx"); }

TEST_CASE(scatternd_invalid_reduction_test)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("scatternd_invalid_reduction_test.onnx"); }));
}

TEST_CASE(scatternd_dyn_test)
{
    // dynamic input.
...
...
@@ -7131,34 +7329,6 @@ TEST_CASE(scatternd_dyn_test)
    EXPECT(p == prog);
}

TEST_CASE(scatternd_add_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
    auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
    auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
    auto r  = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2);
    mm->add_return({r});
    auto prog = migraphx::parse_onnx("scatternd_add_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(scatternd_mul_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
    auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
    auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
    auto r  = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2);
    mm->add_return({r});
    auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(selu_test)
{
    migraphx::program p;
...
...
@@ -8436,6 +8606,86 @@ TEST_CASE(undefined_test)
    EXPECT(p == prog);
}

TEST_CASE(unique_dynamic_sorted_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    migraphx::shape s{migraphx::shape::float_type, {6}};
    auto x = mm->add_parameter("X", s);

    auto out   = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
    auto y     = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
    auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
    mm->add_return({y, y_ind, x_ind, count});

    auto prog = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(unique_dynamic_sorted_3D_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    migraphx::shape s{migraphx::shape::int64_type, {4, 4, 4}};
    auto x = mm->add_parameter("X", s);

    auto out   = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}}), x);
    auto y     = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
    auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
    mm->add_return({y, y_ind, x_ind, count});

    auto prog = migraphx::parse_onnx("unique_dynamic_sorted_3D_test.onnx");
    EXPECT(p == prog);
}
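
// With no "axis" attribute, ONNX Unique operates on the flattened input by
// default, so the {4, 4, 4} tensor above is treated as 64 scalar elements.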

TEST_CASE(unique_sorted_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    migraphx::shape s_x{migraphx::shape::float_type, {6}};
    std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
    auto x = mm->add_literal(migraphx::literal(s_x, x_data));

    auto out   = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
    auto y     = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
    auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
    mm->add_return({y, y_idx, x_idx, count});

    auto prog = migraphx::parse_onnx("unique_sorted_test.onnx");
    EXPECT(p == prog);
}
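
// For x = {2, 1, 1, 3, 4, 3} with sorted = 1, the ONNX Unique specification gives:
//   y     = {1, 2, 3, 4}         (sorted unique values)
//   y_idx = {1, 0, 3, 4}         (first occurrence of each y element in x)
//   x_idx = {1, 0, 0, 2, 3, 2}   (index into y reconstructing each x element)
//   count = {2, 1, 2, 1}         (occurrences of each y element in x)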

TEST_CASE(unique_unsorted_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    migraphx::shape s_x{migraphx::shape::float_type, {6}};
    std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
    auto x = mm->add_literal(migraphx::literal(s_x, x_data));

    auto out   = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 0}, {"axis", 0}}), x);
    auto y     = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
    auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
    mm->add_return({y, y_idx, x_idx, count});

    auto prog = migraphx::parse_onnx("unique_unsorted_test.onnx");
    EXPECT(p == prog);
}
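
// With sorted = 0 the same input keeps first-occurrence order:
//   y = {2, 1, 3, 4}, y_idx = {0, 1, 3, 4}, x_idx = {0, 1, 1, 2, 3, 2},
//   count = {1, 2, 2, 1}.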

TEST_CASE(unknown_test)
{
    migraphx::program p;
...
...
test/onnx/qlinearaveragepool_1d_test.onnx 0 → 100644
View file @ 538dbd75
File added
test/onnx/qlinearaveragepool_2d_ceil_test.onnx 0 → 100644
View file @ 538dbd75
File added
test/onnx/qlinearaveragepool_2d_dilations_test.onnx 0 → 100644
View file @ 538dbd75
File added
test/onnx/qlinearaveragepool_2d_pads_count_include_pad_test.onnx 0 → 100644
View file @ 538dbd75
File added
test/onnx/qlinearaveragepool_2d_same_lower_test.onnx 0 → 100644
View file @ 538dbd75
File added
test/onnx/qlinearaveragepool_2d_same_upper_test.onnx 0 → 100644
View file @ 538dbd75
File added
test/onnx/qlinearaveragepool_2d_strides_test.onnx 0 → 100644
View file @ 538dbd75
File added