Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cac6c759
Commit
cac6c759
authored
Dec 13, 2023
by
Paul
Browse files
Merge
parents
4bde67c4
a60bdb67
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
472 additions
and
282 deletions
+472
-282
src/quantize_8bits.cpp
src/quantize_8bits.cpp
+18
-10
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+9
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+5
-0
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-1
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+2
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
src/targets/ref/CMakeLists.txt
src/targets/ref/CMakeLists.txt
+0
-5
src/targets/ref/gemm.cpp
src/targets/ref/gemm.cpp
+0
-157
src/targets/ref/include/migraphx/ref/gemm.hpp
src/targets/ref/include/migraphx/ref/gemm.hpp
+0
-46
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+6
-17
test/onnx/.onnxrt-commit
test/onnx/.onnxrt-commit
+1
-1
test/onnx/dynamicquantizelinear_1d_test.onnx
test/onnx/dynamicquantizelinear_1d_test.onnx
+19
-0
test/onnx/dynamicquantizelinear_2d_test.onnx
test/onnx/dynamicquantizelinear_2d_test.onnx
+19
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+84
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+101
-4
test/onnx/qlinearconcat_3d_test.onnx
test/onnx/qlinearconcat_3d_test.onnx
+0
-0
test/onnx/qlinearconcat_test.onnx
test/onnx/qlinearconcat_test.onnx
+0
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+127
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+78
-37
No files found.
src/quantize_
int8
.cpp
→
src/quantize_
8bits
.cpp
View file @
cac6c759
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_
int8
.hpp>
#include <migraphx/quantize_
8bits
.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
...
...
@@ -41,8 +41,6 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_INT8_QUANTIZATION_PARAMS
)
static
std
::
vector
<
shape
::
type_t
>&
get_quantizable_type
()
{
static
std
::
vector
<
shape
::
type_t
>
quantable_types
=
{
...
...
@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return
quantable_types
;
}
void
quantize_
int8
_pass
::
apply
(
module
&
m
)
const
// NOLINT
void
quantize_
8bits
_pass
::
apply
(
module
&
m
)
const
// NOLINT
{
const
auto
&
quantizable_types
=
get_quantizable_type
();
for
(
auto
ins
:
iterator_for
(
m
))
...
...
@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto
input
=
ins
->
inputs
().
front
();
auto
s
=
input
->
get_shape
();
if
(
contains
(
quantizable_types
,
s
.
type
())
and
s
.
type
()
!=
shape
::
int8_type
)
if
(
contains
(
quantizable_types
,
s
.
type
())
and
s
.
type
()
!=
precision
)
{
auto
zero_point
=
m
.
add_literal
(
static_cast
<
int8_t
>
(
param
.
second
));
auto
zero_point
=
m
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
precision
},
{
param
.
second
}});
auto
scale
=
m
.
add_literal
(
literal
({
s
.
type
()},
{
1.0
f
/
param
.
first
}));
const
auto
&
lens
=
s
.
lens
();
scale
=
...
...
@@ -87,9 +86,11 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void
capture_arguments_pass
::
apply
(
module
&
m
)
const
// NOLINT
{
assert
(
param_index
!=
nullptr
);
const
auto
&
quantizable_types
=
get_quantizable_type
();
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
not
contains
(
ins_names
,
ins
->
name
()))
if
(
(
not
contains
(
ins_names
,
ins
->
name
()))
or
(
ins
->
name
()
==
"convert"
))
{
continue
;
}
...
...
@@ -98,8 +99,15 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
std
::
vector
<
instruction_ref
>
new_args
;
for
(
auto
input
:
inputs
)
{
auto
new_in
=
m
.
insert_instruction
(
ins
,
op
::
capture
{(
*
param_index
)
++
,
f
},
input
);
new_args
.
push_back
(
new_in
);
if
(
contains
(
quantizable_types
,
input
->
get_shape
().
type
()))
{
auto
new_in
=
m
.
insert_instruction
(
ins
,
op
::
capture
{(
*
param_index
)
++
,
f
},
input
);
new_args
.
push_back
(
new_in
);
}
else
{
new_args
.
push_back
(
input
);
}
}
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
new_args
);
}
...
...
src/simplify_qdq.cpp
View file @
cac6c759
...
...
@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool
diff_shapes_equal_vals
=
false
;
visit_all
(
ins1
->
get_literal
(),
ins2
->
get_literal
())([
&
](
const
auto
l1
,
const
auto
l2
)
{
diff_shapes_equal_vals
=
std
::
all_of
(
l1
.
begin
()
+
1
,
l1
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
l1
.
front
());
})
and
std
::
all_of
(
l2
.
begin
(),
l2
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
l1
.
front
());
});
std
::
all_of
(
l1
.
begin
()
+
1
,
l1
.
end
(),
[
&
](
auto
v
)
{
return
((
float_equal
(
v
,
l1
.
front
()))
or
(
std
::
isinf
(
l1
.
front
())
and
std
::
isinf
(
v
)));
})
and
std
::
all_of
(
l2
.
begin
(),
l2
.
end
(),
[
&
](
auto
v
)
{
return
((
float_equal
(
v
,
l1
.
front
()))
or
(
std
::
isinf
(
l1
.
front
())
and
std
::
isinf
(
v
)));
});
});
return
(
x
==
y
)
or
diff_shapes_equal_vals
;
...
...
src/simplify_reshapes.cpp
View file @
cac6c759
...
...
@@ -183,6 +183,11 @@ struct find_nested_convert
auto
x
=
ins
->
inputs
().
front
();
auto
input
=
x
->
inputs
().
front
();
while
(
input
->
name
()
==
"convert"
)
{
input
=
input
->
inputs
().
front
();
}
if
(
ins
->
get_shape
()
!=
input
->
get_shape
())
return
;
...
...
src/targets/gpu/gemm_impl.cpp
View file @
cac6c759
...
...
@@ -195,7 +195,7 @@ struct gemm_impl
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg
_type
;
output_type
=
get
_type
(
input_shapes
[
2
].
type
())
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
cac6c759
...
...
@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
or
output_shape
.
type
()
==
migraphx
::
shape
::
float_type
)
{
gemm_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
cac6c759
...
...
@@ -26,10 +26,12 @@
#include <migraphx/kernels/integral_constant.hpp>
// Similiar to decltype(auto) except it will propagate any substitution failures
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
...
...
src/targets/gpu/target.cpp
View file @
cac6c759
...
...
@@ -110,6 +110,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
if
(
not
gpu
::
rocblas_fp8_available
())
{
unsupported_fp8_ops
.
insert
(
"dot"
);
unsupported_fp8_ops
.
insert
(
"quant_dot"
);
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops
.
insert
(
"pooling"
);
...
...
src/targets/ref/CMakeLists.txt
View file @
cac6c759
...
...
@@ -25,18 +25,13 @@
add_library
(
migraphx_ref
target.cpp
lowering.cpp
gemm.cpp
)
set_target_properties
(
migraphx_ref PROPERTIES EXPORT_NAME ref
)
rocm_set_soversion
(
migraphx_ref
${
MIGRAPHX_SO_VERSION
}
)
find_path
(
BLAZE_INCLUDE blaze/Blaze.h
)
rocm_clang_tidy_check
(
migraphx_ref
)
target_link_libraries
(
migraphx_ref PRIVATE Threads::Threads
)
target_link_libraries
(
migraphx_ref PUBLIC migraphx
)
target_include_directories
(
migraphx_ref SYSTEM PRIVATE
${
BLAZE_INCLUDE
}
)
target_compile_definitions
(
migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS
)
migraphx_generate_export_header
(
migraphx_ref
)
...
...
src/targets/ref/gemm.cpp
deleted
100644 → 0
View file @
4bde67c4
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
ref
{
template
<
class
T
>
using
matrix
=
blaze
::
CustomMatrix
<
T
,
blaze
::
unaligned
,
blaze
::
unpadded
>
;
// NOLINT
template
<
class
T
>
static
auto
make_mat
(
tensor_view
<
T
>
x
)
{
const
auto
&
s
=
x
.
get_shape
();
// assert(s.lens().size() == 2);
std
::
size_t
n_dims
=
s
.
lens
().
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
if
(
s
.
transposed
())
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_1
],
s
.
lens
()[
dim_0
],
s
.
strides
()[
dim_1
]};
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_0
],
s
.
lens
()[
dim_1
],
s
.
strides
()[
dim_0
]};
}
template
<
class
T
,
class
F
>
static
void
visit_mat
(
tensor_view
<
T
>
x
,
F
f
)
{
auto
mat
=
make_mat
(
x
);
if
(
x
.
get_shape
().
transposed
())
f
(
blaze
::
trans
(
mat
));
else
f
(
mat
);
}
template
<
class
T
>
struct
is_fast_gemm_type
:
std
::
false_type
{
};
template
<
>
struct
is_fast_gemm_type
<
float
>
:
std
::
true_type
{
};
template
<
class
T
,
class
F
>
void
migemm_impl
(
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
F
alpha
,
F
beta
,
std
::
true_type
)
{
visit_mat
(
amat
,
[
&
](
const
auto
&
a
)
{
visit_mat
(
bmat
,
[
&
](
const
auto
&
b
)
{
auto
c
=
make_mat
(
cmat
);
c
=
beta
*
c
;
// This is a simple optimization to avoid
// compute A * B if alpha is 0.0
if
(
alpha
!=
0.0
)
{
c
=
c
+
alpha
*
a
*
b
;
}
});
});
}
template
<
class
T
,
class
F
>
void
migemm_impl
(
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
F
alpha
,
F
beta
,
std
::
false_type
)
{
std
::
size_t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
auto
k
=
amat
.
get_shape
().
lens
()[
dim_1
];
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
auto
cs
=
cmat
.
get_shape
();
par_for
(
cs
.
elements
(),
[
&
](
auto
i
)
{
auto
c_idx
=
cs
.
multi
(
i
);
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
double
s
=
0.0
;
dfor
(
k
)([
&
](
auto
kk
)
{
a_idx
[
dim_1
]
=
b_idx
[
dim_0
]
=
kk
;
s
+=
amat
(
a_idx
.
begin
(),
a_idx
.
end
())
*
bmat
(
b_idx
.
begin
(),
b_idx
.
end
());
});
cmat
(
c_idx
.
begin
(),
c_idx
.
end
())
=
alpha
*
s
+
cmat
(
c_idx
.
begin
(),
c_idx
.
end
())
*
beta
;
});
}
template
<
class
T
,
class
F
>
void
migemm_impl
(
tensor_view
<
T
>
cmat
,
tensor_view
<
T
>
amat
,
tensor_view
<
T
>
bmat
,
F
alpha
,
F
beta
)
{
auto
lens
=
amat
.
get_shape
().
lens
();
bool
batch_mul
=
std
::
accumulate
(
lens
.
rbegin
()
+
2
,
lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
())
==
1
;
if
(
batch_mul
)
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
is_fast_gemm_type
<
T
>
{});
}
else
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
,
std
::
false_type
{});
}
}
template
<
class
F
>
void
migemm_tpl
(
const
argument
&
c_arg
,
const
argument
&
a_arg
,
const
argument
&
b_arg
,
F
alpha
,
F
beta
)
{
visit_all
(
c_arg
,
a_arg
,
b_arg
)(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
migemm_impl
(
cmat
,
amat
,
bmat
,
alpha
,
beta
);
});
}
void
migemm
(
const
argument
&
c_arg
,
const
argument
&
a_arg
,
const
argument
&
b_arg
,
float
alpha
,
float
beta
)
{
migemm_tpl
(
c_arg
,
a_arg
,
b_arg
,
alpha
,
beta
);
}
void
migemm
(
const
argument
&
c_arg
,
const
argument
&
a_arg
,
const
argument
&
b_arg
,
int32_t
alpha
,
int32_t
beta
)
{
migemm_tpl
(
c_arg
,
a_arg
,
b_arg
,
alpha
,
beta
);
}
}
// namespace ref
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/ref/include/migraphx/ref/gemm.hpp
deleted
100644 → 0
View file @
4bde67c4
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
ref
{
void
migemm
(
const
argument
&
c_arg
,
const
argument
&
a_arg
,
const
argument
&
b_arg
,
float
alpha
,
float
beta
);
void
migemm
(
const
argument
&
c_arg
,
const
argument
&
a_arg
,
const
argument
&
b_arg
,
int32_t
alpha
,
int32_t
beta
);
}
// namespace ref
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/ref/lowering.cpp
View file @
cac6c759
...
...
@@ -44,7 +44,6 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/ref/gemm.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
...
...
@@ -283,8 +282,8 @@ struct ref_gemm
argument
compute
(
context
&
,
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
dyn_out
.
computed_shape
};
migemm
(
result
,
args
[
0
],
args
[
1
]
,
1.0
f
,
0.0
f
);
visit_all
(
result
,
args
[
0
],
args
[
1
]
)(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
return
result
;
}
};
...
...
@@ -306,24 +305,14 @@ struct ref_quant_gemm
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// first, convert the args[0] and args[1] from int8_t to int32_t
argument
arg_0
{{
shape
::
int32_type
,
{
args
.
at
(
0
).
get_shape
().
lens
()}}};
argument
arg_1
{{
shape
::
int32_type
,
{
args
.
at
(
1
).
get_shape
().
lens
()}}};
arg_0
.
visit
([
&
](
auto
output
)
{
args
.
at
(
0
).
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
});
arg_1
.
visit
([
&
](
auto
output
)
{
args
.
at
(
1
).
visit
(
[
&
](
auto
input
)
{
std
::
copy
(
input
.
begin
(),
input
.
end
(),
output
.
begin
());
});
result
.
visit
([
&
](
auto
cmat
)
{
visit_all
(
args
.
at
(
0
),
args
.
at
(
1
))(
[
&
](
auto
amat
,
auto
bmat
)
{
return
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
});
migemm
(
result
,
arg_0
,
arg_1
,
int32_t
{
1
},
int32_t
{
0
});
return
result
;
}
};
MIGRAPHX_REGISTER_OP
(
ref_gemm
)
template
<
class
Op
>
...
...
test/onnx/.onnxrt-commit
View file @
cac6c759
d69842226b47e5336568103541b071447caeb9bf
44b58437402b207c8216f3be8c75accb7409be1c
test/onnx/dynamicquantizelinear_1d_test.onnx
0 → 100644
View file @
cac6c759
dynamicquantizelinear_1d_test:
4
xyy_scaley_zero_point"DynamicQuantizeLineardynamicquantizelinear_1d_testZ
x
b
y
b
y_scale
b
y_zero_point
B
\ No newline at end of file
test/onnx/dynamicquantizelinear_2d_test.onnx
0 → 100644
View file @
cac6c759
dynamicquantizelinear_2d_test:
4
xyy_scaley_zero_point"DynamicQuantizeLineardynamicquantizelinear_2d_testZ
x
b
y
b
y_scale
b
y_zero_point
B
\ No newline at end of file
test/onnx/gen_onnx.py
View file @
cac6c759
...
...
@@ -1968,6 +1968,40 @@ def dropout_test():
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
dynamicquantizelinear_1d_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
6
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
UINT8
,
[
6
])
y_scale
=
helper
.
make_tensor_value_info
(
'y_scale'
,
TensorProto
.
FLOAT
,
[
1
])
y_zero_point
=
helper
.
make_tensor_value_info
(
'y_zero_point'
,
TensorProto
.
UINT8
,
[
1
])
node
=
onnx
.
helper
.
make_node
(
'DynamicQuantizeLinear'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
,
'y_scale'
,
'y_zero_point'
],
)
return
([
node
],
[
x
],
[
y
,
y_scale
,
y_zero_point
])
@
onnx_test
()
def
dynamicquantizelinear_2d_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
3
,
4
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
UINT8
,
[
3
,
4
])
y_scale
=
helper
.
make_tensor_value_info
(
'y_scale'
,
TensorProto
.
FLOAT
,
[
1
])
y_zero_point
=
helper
.
make_tensor_value_info
(
'y_zero_point'
,
TensorProto
.
UINT8
,
[
1
])
node
=
onnx
.
helper
.
make_node
(
'DynamicQuantizeLinear'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
,
'y_scale'
,
'y_zero_point'
],
)
return
([
node
],
[
x
],
[
y
,
y_scale
,
y_zero_point
])
@
onnx_test
()
def
elu_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
3
])
...
...
@@ -6251,6 +6285,56 @@ def qlinearaveragepool_nt_cip_test():
return
([
node
],
[
x
],
[
y
],
[
x_scale
,
x_zero_point
,
y_scale
,
y_zero_point
])
@
onnx_test
()
def
qlinearconcat_test
():
y_scale
=
helper
.
make_tensor
(
'1'
,
TensorProto
.
FLOAT
,
[],
[
0.5
])
y_zero_point
=
helper
.
make_tensor
(
'2'
,
TensorProto
.
INT8
,
[],
[
2
])
t0
=
helper
.
make_tensor_value_info
(
't0'
,
TensorProto
.
INT8
,
[
2
])
s0
=
helper
.
make_tensor
(
'3'
,
TensorProto
.
FLOAT
,
[],
[
0.5
])
zp0
=
helper
.
make_tensor
(
'4'
,
TensorProto
.
INT8
,
[],
[
1
])
t1
=
helper
.
make_tensor_value_info
(
't1'
,
TensorProto
.
INT8
,
[
3
])
s1
=
helper
.
make_tensor
(
'5'
,
TensorProto
.
FLOAT
,
[],
[
0.25
])
zp1
=
helper
.
make_tensor
(
'6'
,
TensorProto
.
INT8
,
[],
[
0
])
y
=
helper
.
make_tensor_value_info
(
'out'
,
TensorProto
.
INT8
,
[
5
])
node
=
onnx
.
helper
.
make_node
(
'QLinearConcat'
,
inputs
=
[
'1'
,
'2'
,
't0'
,
'3'
,
'4'
,
't1'
,
'5'
,
'6'
],
axis
=
0
,
outputs
=
[
'out'
],
)
return
([
node
],
[
t0
,
t1
],
[
y
],
[
y_scale
,
y_zero_point
,
s0
,
zp0
,
s1
,
zp1
])
@
onnx_test
()
def
qlinearconcat_3d_test
():
y_scale
=
helper
.
make_tensor
(
'1'
,
TensorProto
.
FLOAT
,
[],
[
0.5
])
y_zero_point
=
helper
.
make_tensor
(
'2'
,
TensorProto
.
INT8
,
[],
[
2
])
t0
=
helper
.
make_tensor_value_info
(
't0'
,
TensorProto
.
INT8
,
[
3
,
4
,
2
])
s0
=
helper
.
make_tensor
(
'3'
,
TensorProto
.
FLOAT
,
[],
[
0.5
])
zp0
=
helper
.
make_tensor
(
'4'
,
TensorProto
.
INT8
,
[],
[
10
])
t1
=
helper
.
make_tensor_value_info
(
't1'
,
TensorProto
.
INT8
,
[
3
,
2
,
2
])
s1
=
helper
.
make_tensor
(
'5'
,
TensorProto
.
FLOAT
,
[],
[
0.4
])
zp1
=
helper
.
make_tensor
(
'6'
,
TensorProto
.
INT8
,
[],
[
20
])
y
=
helper
.
make_tensor_value_info
(
'out'
,
TensorProto
.
UINT8
,
[
3
,
6
,
2
])
node
=
onnx
.
helper
.
make_node
(
'QLinearConcat'
,
inputs
=
[
'1'
,
'2'
,
't0'
,
'3'
,
'4'
,
't1'
,
'5'
,
'6'
],
axis
=
1
,
outputs
=
[
'out'
],
)
return
([
node
],
[
t0
,
t1
],
[
y
],
[
y_scale
,
y_zero_point
,
s0
,
zp0
,
s1
,
zp1
])
@
onnx_test
()
def
qlinearconv_test
():
# https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
...
...
test/onnx/onnx_test.cpp
View file @
cac6c759
...
...
@@ -1865,6 +1865,50 @@ TEST_CASE(depthtospace_simple_test)
EXPECT(p == prog);
}
TEST_CASE(dynamicquantizelinear_2d_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x_dims = {3, 4};
auto x_type = migraphx::shape::float_type;
auto x = mm->add_parameter("x", {x_type, x_dims});
auto l0 = mm->add_literal({0.f});
auto x_reshaped = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), x);
x_reshaped = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, l0);
auto q_range = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
auto max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
auto min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);
auto sub0 = mm->add_instruction(migraphx::make_op("sub"), max_x, min_x);
auto y_scale = mm->add_instruction(migraphx::make_op("div"), sub0, q_range);
auto q_min = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::min()}});
auto q_max = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
auto sub1 = mm->add_instruction(migraphx::make_op("sub"), q_min, min_x);
auto interm_zp = mm->add_instruction(migraphx::make_op("div"), sub1, y_scale);
auto saturate = mm->add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
auto round = mm->add_instruction(migraphx::make_op("nearbyint"), saturate);
auto y_zero_point = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::uint8_type}}), round);
auto scale_y_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_dims}}), y_scale);
auto y_pt_c_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x_dims}}), y_zero_point);
mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_y_bcast, y_pt_c_bcast);
auto prog = optimize_onnx("dynamicquantizelinear_2d_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(spacetodepth_test)
{
migraphx::program p;
...
...
@@ -2863,12 +2907,12 @@ migraphx::program make_group_norm(const std::vector<int64_t>& input_dims,
auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});
auto x_reshaped =
auto x_reshaped
d
=
mm->add_instruction(migraphx::make_op("reshape", {{"dims", reshape_dims}}), x);
auto mean =
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x_reshaped);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x_reshaped, mean});
auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x_reshaped, mean});
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x_reshaped
d
);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x_reshaped
d
, mean});
auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x_reshaped
d
, mean});
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}),
x_sqdiff_mean);
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
...
...
@@ -5645,6 +5689,59 @@ TEST_CASE(qlinearaveragepool_notset_test)
EXPECT(p == prog);
}
TEST_CASE(qlinearconcat_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto sc_y = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_y = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {2}});
auto sc_0 = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto z_pt_0 = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {1}});
auto sc_1 = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.25}});
auto z_pt_1 = mm->add_literal(migraphx::literal{migraphx::shape::int8_type, {0}});
auto t0 = mm->add_parameter("t0", {migraphx::shape::int8_type, {2}});
auto t1 = mm->add_parameter("t1", {migraphx::shape::int8_type, {3}});
auto scale_0_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), sc_0);
auto z_pt_0_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2}}}), z_pt_0);
auto fp_0 =
mm->add_instruction(migraphx::make_op("dequantizelinear"), t0, scale_0_bcast, z_pt_0_bcast);
auto scale_1_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), sc_1);
auto z_pt_1_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), z_pt_1);
auto fp_1 =
mm->add_instruction(migraphx::make_op("dequantizelinear"), t1, scale_1_bcast, z_pt_1_bcast);
auto fp_y = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), fp_0, fp_1);
auto scale_y_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), sc_y);
auto z_pt_y_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), z_pt_y);
auto y =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_y, scale_y_bcast, z_pt_y_bcast);
mm->add_return({y});
auto prog = migraphx::parse_onnx("qlinearconcat_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(qlinearconv_test)
{
migraphx::program p;
...
...
test/onnx/qlinearconcat_3d_test.onnx
0 → 100644
View file @
cac6c759
File added
test/onnx/qlinearconcat_test.onnx
0 → 100644
View file @
cac6c759
File added
test/onnx/verify_onnx.cpp
View file @
cac6c759
...
...
@@ -351,6 +351,87 @@ TEST_CASE(depthtospace_simple_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
dynamicquantizelinear_1d_test
)
{
auto
p
=
migraphx
::
parse_onnx
(
"dynamicquantizelinear_1d_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
data
{
0
,
2
,
-
3
,
-
2.5
,
1.34
,
0.5
};
migraphx
::
shape
s_x
{
migraphx
::
shape
::
float_type
,
{
6
}};
migraphx
::
parameter_map
pp
;
pp
[
"x"
]
=
migraphx
::
argument
(
s_x
,
data
.
data
());
auto
results
=
p
.
eval
(
pp
);
std
::
vector
<
uint8_t
>
y_results
;
results
.
at
(
0
).
visit
([
&
](
auto
output
)
{
y_results
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
y_gold
=
{
153
,
255
,
0
,
26
,
221
,
179
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_results
,
y_gold
));
std
::
vector
<
float
>
y_scale
;
results
.
at
(
1
).
visit
([
&
](
auto
output
)
{
y_scale
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
y_scale_gold
=
{
0.0196078438
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_scale
,
y_scale_gold
));
std
::
vector
<
uint8_t
>
y_zpt
;
results
.
at
(
2
).
visit
([
&
](
auto
output
)
{
y_zpt
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
y_zpt_gold
=
{
153
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_zpt
,
y_zpt_gold
));
}
TEST_CASE
(
dynamicquantizelinear_1d_max_adjusted_test
)
{
auto
p
=
migraphx
::
parse_onnx
(
"dynamicquantizelinear_1d_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
data
{
-
1.0
,
-
2.1
,
-
1.3
,
-
2.5
,
-
3.34
,
-
4.0
};
migraphx
::
shape
s_x
{
migraphx
::
shape
::
float_type
,
{
6
}};
migraphx
::
parameter_map
pp
;
pp
[
"x"
]
=
migraphx
::
argument
(
s_x
,
data
.
data
());
auto
results
=
p
.
eval
(
pp
);
std
::
vector
<
uint8_t
>
y_results
;
results
.
at
(
0
).
visit
([
&
](
auto
output
)
{
y_results
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
y_gold
=
{
191
,
121
,
172
,
96
,
42
,
0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_results
,
y_gold
));
std
::
vector
<
float
>
y_scale
;
results
.
at
(
1
).
visit
([
&
](
auto
output
)
{
y_scale
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
y_scale_gold
=
{
0.0156862754
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_scale
,
y_scale_gold
));
std
::
vector
<
uint8_t
>
y_zpt
;
results
.
at
(
2
).
visit
([
&
](
auto
output
)
{
y_zpt
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
y_zpt_gold
=
{
255
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_zpt
,
y_zpt_gold
));
}
TEST_CASE
(
dynamicquantizelinear_2d_test
)
{
auto
p
=
migraphx
::
parse_onnx
(
"dynamicquantizelinear_2d_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
float
>
data
{
1.0
,
2.1
,
1.3
,
2.5
,
3.34
,
4.0
,
1.5
,
2.6
,
3.9
,
4.0
,
3.0
,
2.345
};
migraphx
::
shape
s_x
{
migraphx
::
shape
::
float_type
,
{
3
,
4
}};
migraphx
::
parameter_map
pp
;
pp
[
"x"
]
=
migraphx
::
argument
(
s_x
,
data
.
data
());
auto
results
=
p
.
eval
(
pp
);
std
::
vector
<
uint8_t
>
y_results
;
results
.
at
(
0
).
visit
([
&
](
auto
output
)
{
y_results
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
y_gold
=
{
64
,
134
,
83
,
159
,
213
,
255
,
96
,
166
,
249
,
255
,
191
,
149
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_results
,
y_gold
));
std
::
vector
<
float
>
y_scale
;
results
.
at
(
1
).
visit
([
&
](
auto
output
)
{
y_scale
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
y_scale_gold
=
{
0.0156862754
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_scale
,
y_scale_gold
));
std
::
vector
<
uint8_t
>
y_zpt
;
results
.
at
(
2
).
visit
([
&
](
auto
output
)
{
y_zpt
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
uint8_t
>
y_zpt_gold
=
{
0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
y_zpt
,
y_zpt_gold
));
}
TEST_CASE
(
spacetodepth_simple_test
)
{
auto
p
=
migraphx
::
parse_onnx
(
"spacetodepth_simple_test.onnx"
);
...
...
@@ -1932,6 +2013,52 @@ TEST_CASE(qlinearaveragepool_nt_cip_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
qlinearconcat_test
)
{
auto
p
=
migraphx
::
parse_onnx
(
"qlinearconcat_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
int8_t
>
data_t0
=
{
2
,
3
};
migraphx
::
shape
s_t0
{
migraphx
::
shape
::
int8_type
,
{
2
}};
migraphx
::
parameter_map
pp
;
pp
[
"t0"
]
=
migraphx
::
argument
(
s_t0
,
data_t0
.
data
());
std
::
vector
<
int8_t
>
data_t1
=
{
6
,
8
,
10
};
migraphx
::
shape
s_t1
{
migraphx
::
shape
::
int8_type
,
{
3
}};
pp
[
"t1"
]
=
migraphx
::
argument
(
s_t1
,
data_t1
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
int8_t
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int8_t
>
gold
=
{
3
,
4
,
5
,
6
,
7
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
qlinearconcat_3d_test
)
{
auto
p
=
migraphx
::
parse_onnx
(
"qlinearconcat_3d_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
std
::
vector
<
int8_t
>
data_t0
=
{
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
,
10
};
migraphx
::
shape
s_t0
{
migraphx
::
shape
::
int8_type
,
{
3
,
4
,
2
}};
migraphx
::
parameter_map
pp
;
pp
[
"t0"
]
=
migraphx
::
argument
(
s_t0
,
data_t0
.
data
());
std
::
vector
<
int8_t
>
data_t1
=
{
25
,
25
,
25
,
25
,
25
,
25
,
25
,
25
,
25
,
25
,
25
,
25
};
migraphx
::
shape
s_t1
{
migraphx
::
shape
::
int8_type
,
{
3
,
2
,
2
}};
pp
[
"t1"
]
=
migraphx
::
argument
(
s_t1
,
data_t1
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
uint8_t
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int8_t
>
gold
=
{
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
6
,
6
,
6
,
6
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
6
,
6
,
6
,
6
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
6
,
6
,
6
,
6
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
qlinearconv_test
)
{
// https://xadupre.github.io/draft/onnx/onnx_doc_folder/onnx__QLinearConv.html
...
...
test/op_shape_test.cpp
View file @
cac6c759
...
...
@@ -2682,36 +2682,26 @@ TEST_CASE(reshape_shape_minus1_reshapes)
}
}
// This uses the permutation to compute the reshape since its simpler than
// trying to calculate strides. As we collapse or expand dimensions, we
// remove the collapsed dimensions or duplicate the expanded dimensions in
// the permutation. Then we renumber the permutation. So for dimensions of 4,
// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshapes to 4, 1, 3,
// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
// final permutation of 4, 0, 1, 2, 3.
TEST_CASE
(
reshape_nonstandard
)
{
auto
input
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
4
,
24
,
1
,
1
,
1
},
migraphx
::
invert_permutation
({
1
,
0
,
2
,
3
,
4
}));
std
::
vector
<
std
::
pair
<
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
int64_t
>>>
tests
{
{{
4
,
24
},
{
1
,
0
}},
{{
4
,
24
,
1
,
1
,
1
,
1
},
{
1
,
0
,
2
,
3
,
4
,
5
}},
{{
4
,
8
,
3
,
1
,
1
},
{
2
,
0
,
1
,
3
,
4
}},
{{
4
,
1
,
3
,
4
,
2
},
{
4
,
0
,
1
,
2
,
3
}},
{{
4
,
1
,
4
,
3
,
2
},
{
4
,
0
,
1
,
2
,
3
}},
{{
4
,
2
,
4
,
3
},
{
3
,
0
,
1
,
2
}},
{{
4
,
2
,
12
,
1
},
{
2
,
0
,
1
,
3
}},
{{
4
,
2
,
1
,
12
},
{
3
,
0
,
1
,
2
}},
{{
4
,
4
,
2
,
3
},
{
3
,
0
,
1
,
2
}},
{{
4
,
8
,
1
,
3
},
{
3
,
0
,
1
,
2
}},
{{
4
,
8
,
3
,
1
},
{
2
,
0
,
1
,
3
}}};
for
(
const
auto
&
[
dims
,
perm
]
:
tests
)
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
tests
{{
4
,
24
},
{
4
,
24
,
1
,
1
,
1
,
1
},
{
4
,
8
,
3
,
1
,
1
},
{
4
,
1
,
3
,
4
,
2
},
{
4
,
1
,
4
,
3
,
2
},
{
4
,
2
,
4
,
3
},
{
4
,
2
,
12
,
1
},
{
4
,
2
,
1
,
12
},
{
4
,
4
,
2
,
3
},
{
4
,
8
,
1
,
3
},
{
4
,
8
,
3
,
1
}};
for
(
auto
dims
:
tests
)
{
migraphx
::
shape
output
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
dims
,
migraphx
::
invert_permutation
(
perm
));
migraphx
::
shape
output
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
dims
};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
input
);
}
}
...
...
@@ -2721,8 +2711,7 @@ TEST_CASE(reshape_nonstandard_squeeze)
auto
input
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
migraphx
::
invert_permutation
({
0
,
2
,
3
,
1
}));
std
::
vector
<
std
::
size_t
>
lens
=
{
2
,
256
,
1280
};
migraphx
::
shape
output
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
lens
,
migraphx
::
invert_permutation
({
0
,
2
,
1
}));
migraphx
::
shape
output
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
lens
};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
}
...
...
@@ -2746,52 +2735,80 @@ TEST_CASE(reshape_nonstandard_error)
}
}
TEST_CASE
(
reshape_transposed_squeeze
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
1
,
4
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_nonpacked_unsqueeze1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
8
}
,
{
32
,
16
,
2
}
};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
8
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_nonpacked_unsqueeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
16
}
,
{
64
,
32
,
2
}
};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
16
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_nonpacked_squeeze
)
TEST_CASE
(
reshape_nonpacked_squeeze
1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
},
{
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_nonpacked_squeeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_unsqueeze1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
}
,
{
0
,
0
,
0
,
1
}
};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_unsqueeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
16
,
80
}
,
{
0
,
0
,
80
,
1
}
};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
16
,
80
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_squeeze
)
TEST_CASE
(
reshape_broadcast_squeeze
1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_squeeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_squeeze3
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
1
,
0
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_broadcast_squeeze_memlayout_change
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
256
,
80
}
,
{
0
,
0
,
0
,
16
}
};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
256
,
80
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
...
...
@@ -2960,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error)
}
}
TEST_CASE
(
reshape_lazy_transposed_squeeze
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
1
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
{
64
}}}),
input
);
}
TEST_CASE
(
reshape_lazy_nonpacked_unsqueeze1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
...
...
@@ -2974,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_lazy_nonpacked_squeeze
)
TEST_CASE
(
reshape_lazy_nonpacked_squeeze
1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
2
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
64
},
{
2
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_lazy_nonpacked_squeeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
32
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
{
64
}}}),
input
);
}
TEST_CASE
(
reshape_lazy_broadcast_unsqueeze1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
...
...
@@ -2995,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_lazy_broadcast_squeeze
)
TEST_CASE
(
reshape_lazy_broadcast_squeeze
1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
2
,
256
,
1280
},
{
0
,
0
,
1
}};
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
output
.
lens
()}}),
input
);
}
TEST_CASE
(
reshape_lazy_broadcast_squeeze2
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
0
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
{
64
}}}),
input
);
}
TEST_CASE
(
reshape_lazy_broadcast_squeeze3
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
16
},
{
1
,
0
}};
throws_shape
(
migraphx
::
make_op
(
"reshape_lazy"
,
{{
"dims"
,
{
64
}}}),
input
);
}
TEST_CASE
(
reshape_lazy_broadcast_squeeze_error
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
16
,
16
,
1280
},
{
0
,
0
,
0
,
1
}};
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment